How to generate text?
๐ค transformers๋ฅผ ํ์ฉํ์ฌ language generation์ ๋ค๋ฅธ decoding methods๋ฅผ ์ฌ์ฉํด๋ณด์!
- Introduction
- Using different decoding methods for language generation with Transformers
- GenerationMixin ๋ฏ์ด๋ณด๊ธฐ
Introduction
์ต๊ทผ ๋ช ๋
๋์ OpenAI์ GPT-2 ๋ชจ๋ธ๊ณผ ๊ฐ์ด ์๋ฐฑ๋ง ๊ฐ์ ์น ํ์ด์ง์์ ํ๋ จ๋ large-scale transformer ๊ธฐ๋ฐ ์ธ์ด ๋ชจ๋ธ์ ๋ฑ์ฅ์ผ๋ก open-ended language generation
์ ๋ํ ๊ด์ฌ์ด ๋์์ก์ต๋๋ค. Conditioned open-ended language generation์ ๊ฒฐ๊ณผ๋ ๊ต์ฅํ ์ธ์์ ์
๋๋ค.
2017๋ ์ transformer๊ฐ ๋ฑ์ฅํ ์ด๋๋ก ์ํคํ ์ณ๋ฅผ ์์ ํ๋ ๋ฐฉ๋ฒ ๋ฐ ๋ฐฉ๋ํ unsupervised ํ์ต ๋ฐ์ดํฐ(self-supervised๋ฅผ ์ํ)๋ค์ด ์์ง๋ง ๋ ๋์ ๋์ฝ๋ฉ ๋ฐฉ๋ฒ(better decoding methods) ๋ํ ์ค์ํ ์ญํ ์ ํ์ต๋๋ค.
์ด ๋ธ๋ก๊ทธ ํฌ์คํธ๋ ๋ค์ํ ๋์ฝ๋ฉ ์ ๋ต์ ๋ํ ๊ฐ๋ตํ ๊ฐ์๋ฅผ ์ ๊ณตํ๊ณ ๋ ์ค์ํ ๊ฒ์ ์ ๋ช ํ ๐ค transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์์ฃผ ์ ์ ๋ ธ๋ ฅ์ผ๋ก ์ด๋ฅผ ๊ตฌํํ ์ ์๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค!
- jinmang2: ๋ํ ์ ๋ธ๋ก๊ทธ ํฌ์คํ ์์ ์ฌ๋ฌ๋ถ๋ค์ ๐ค transformers์์ Decoding methods๊ฐ ์ด๋ป๊ฒ ๊ตฌํ๋์ด ์๋์ง๋ ์ ์ ์์ต๋๋ค.
๋ณธ ๊ฒ์๊ธ์์ ๋ค๋ฃจ๋ ๋ชจ๋ ๊ธฐ๋ฅ๋ค(functionalities)์ ๋ชจ๋ auto-regressive language generation(๋ ์์ธํ ๋ด์ฉ์ Jay Alammar๋์ ํฌ์คํ
์ ํ์ธํด์ฃผ์ธ์)์ ์ฌ์ฉํ ์ ์์ต๋๋ค. auto-regressive
๋ ๊ฐ๋จํ ๋งํด ์๋ ๊ฐ์ ์ ๊ธฐ๋ฐ์ผ๋ก ๋ ๋ฐฉ๋ฒ๋ก ์
๋๋ค.
Assumption
==========
The probability distribution of a word sequence can be decomposed into
the product of conditional next word distributions.
word sequence, ์ฆ, ๋จ์ด๋ก ์ด๋ฃจ์ด์ง ์์ด(๋ฌธ์ฅ์ด๋ผ๊ณ ์ดํดํ์๋ฉด ๋ฉ๋๋ค)์ ๋ค์ ๋จ์ด๊ฐ ๋์ฌ ์กฐ๊ฑด๋ถ ๋ถํฌ๋ค์ ๊ณฑ์ผ๋ก ๋ถํด๊ฐ ๊ฐ๋ฅํ๋ค๋ผ๋ ์ ์ ๋ฅผ ๊น๊ณ ์์ต๋๋ค.
์์์ผ๋ก ๋ณด๊ฒ ์ต๋๋ค.
$$P(w_{1:T}|W_0)=\prod_{t=1}^{T}P(w_t|w_{1:t-1},W_0),\;\text{with}\, w_{1:0}=\emptyset$$
- $W_0$์ initial context word sequence
- $T$๋ ๋ฌธ์ฅ์ ๊ธธ์ด์
๋๋ค.
- patrick๋ ํฌ์คํ ๋ณธ๋ฌธ์์๋ ์์ฑ๋ ๋ฌธ์ฅ์ ๊ธธ์ด๋ฅผ ์๋ฏธํ์๋ ๊ฒ ๊ฐ์ต๋๋ค. ์์ฑ์ ๋จ์ด(ํน์ ํ ํฐ) ๋จ์๋ก ๊ธธ์ด๊ฐ 1์ฉ ๋์ด๋๊ธฐ ๋๋ฌธ์ on-the-fly๋ผ๋ ํํ์ ์ฌ์ฉํ์ ๊ฒ ๊ฐ์ต๋๋ค.
- Timestep $t$๊ฐ $T$๊ฐ ๋๋ฉด? ๋ฌธ์ฅ์ ์ ๋ถ ๋ถํดํ๋ค๋ ๋ป์ด๊ฒ ์ฃ ?(conditional next word dists.๋ก) ์ด ๋๋ $P(w_t|w_{1:t-1},W_0)$์์
EOS
ํ ํฐ์ด ์์ฑ๋ฉ๋๋ค.-
EOS
๋ End of Sentence(ํน์ Sequence)์ ์ฝ์๋ก ๋ฌธ์ฅ์ ๋์ ์๋ฏธํฉ๋๋ค.
-
์ด์ ๊ฐ์ฅ ์ค์ํ ๋ค ๊ฐ์ง decoding ๋ฐฉ๋ฒ์ ๋ํ์ฌ ์๊ฐํด๋๋ฆฌ๋๋ก ํ๊ฒ ์ต๋๋ค.
๋ธ๋ก๊ทธ ํฌ์คํ ์ ์์๋ฅผ ๋ณด์ ๋ ์ข๊ณ ์๋ฌด๋๋ ํ๊ตญ์ด๋ก ๋ฒ์ญ๋ ํฌ์คํ ์ด๊ธฐ ๋๋ฌธ์ ํ๊ตญ์ด ์ธ์ด ๋ชจ๋ธ๋ก ์์๋ฅผ ๋ค์ด๋๋ฆฌ๋ ๊ฒ์ด ๋ณด๊ธฐ ์ข์ ๊ฒ ๊ฐ์ต๋๋ค. ์ด์ ๋ํ ์ธํ ์ ์ํํ์ฃ .
- tensorflow๋ ์ฌ์ฉํ์ง ์์ต๋๋ค.
!pip install -q git+https://github.com/huggingface/transformers.git
์์์์ ์ฌ์ฉํ ๋ชจ๋ธ์ SKT์์ ๊ฐ๋ฐํ KoGPT2์ด๋ฉฐ ์์ธํ ์ค๋ช ์ ์๋ ๋งํฌ๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์.
๊ฐ๋ตํ ์๊ฐ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- Vocab size: 51,200
- ์ด๋ชจ์ง, ์ด๋ชจํฐ์ฝ ๋ฑ์ ์ถ๊ฐํ์ฌ ํด๋น ํ ํฐ์ ์ธ์ ๋ฅ๋ ฅ ๊ฐ์
- unused token์ 100๊ฐ ์ฌ์ฉํ์ฌ ํ์ํ task์ ๋ฐ๋ผ ์์ ๋กญ๊ฒ ์ ์ ๊ฐ๋ฅ
- metaspace๋
โ
Model | # of params | Type | # of layers | # of heads | ffn_dim | hidden_dims |
---|---|---|---|---|---|---|
kogpt2-base-v2 |
125M | Decoder | 12 | 12 | 3072 | 768 |
- ์ฌ์ฉํ ๋ฐ์ดํฐ๋ ํ๊ตญ์ด ์ํค ๋ฐฑ๊ณผ, ๋ด์ค, ๋ชจ๋์ ๋ง๋ญ์น v1.0, ์ฒญ์๋ ๊ตญ๋ฏผ์ฒญ์ ๋ฑ ๋ค์ํ ๋ฐ์ดํฐ
from transformers import AutoTokenizer, AutoModelForCausalLM
model_path_or_name = "skt/kogpt2-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name, force_download=True)
model = AutoModelForCausalLM.from_pretrained(model_path_or_name, force_download=True)
tokenizer # Rust๋ก ๊ตฌํ๋ `Fast`ํ tokenizer
tokenizer.tokenize("๊ทผ์ก์ด ์ปค์ง๊ธฐ ์ํด์๋")
model.__class__.__name__ # AutoModelForCausalLM์ผ๋ก GPT2์ CLM class ํธ์ถ
num_of_parameters = sum(p.numel() for n, p in model.named_parameters())
print(f"{num_of_parameters}") # 125M
LMHeadModel(Causal-LM์ ํ๊ธฐ ์ํ ๋ชจ๋ธ)์ ํฌ๊ฒ ์ธ ํํธ๋ก ๋๋์ด์ ธ ์์ต๋๋ค.
(1) word token/position embedding
- ์ธ์ฝ๋ฉ๋ word sequence์ ๋ํด embedding value๋ฅผ ๊ณ์ฐํฉ๋๋ค.
- ๋จ์ด(ํน์ ํ ํฐ)์ ๋ป์ word token embedding์ผ๋ก,
- ๋จ์ด(ํน์ ํ ํฐ)์ ์์น๋ฅผ word position embedding์ผ๋ก ๋ฒกํฐํํด์ค๋๋ค.
- position embedding์ ๊ฒฝ์ฐ layer์ ํด๋น ์ ๋ณด๋ฅผ ๋ฃ์ด์ฃผ๋ ๊ฒฝ์ฐ๋ ์์ง๋ง (relative position embedding) ํด๋น ํฌ์คํ ์ ๋ฒ์ฃผ๋ฅผ ๋์ด์๊ธฐ ๋๋ฌธ์ ํฅํ ์๊ฐํ๋๋ก ํ๊ฒ ์ต๋๋ค.
Note that: ์๋ ์ฝ๋๋ GPT2 script์์ ๋ฐ์ทํ ์ฝ๋์ ๋๋ค! wte, wpe๋ฅผ ๊ตฌํ๋ ๊ฒ์ model by model์ด์์!"
word_sequence = "๊ทผ์ก์ด ์ปค์ง๊ธฐ ์ํด์๋"
inputs = tokenizer(word_sequence, return_tensors="pt")
input_ids = inputs["input_ids"]
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids
์ค์ ๋ก ๋ชจ๋ธ ์ธํ์ ์ด๋ป๊ฒ ๋ค์ด๊ฐ๋ ํ์ธํด๋ณด์ฃ .
import inspect
# input_ids, attention_mask, token_type_ids, position_ids๊ฐ ์ค์ํด์
# forward์ __call__์ ๊ด๊ณ๋ `torch.nn.Module`์ ์์๋ฐ์์ ๊ทธ๋์
# ์ด๊ฑด ๋ค์ ํ์ต ๊ธฐํ๋ก!
inspect.signature(model.transformer.forward).parameters
์์์ ํ์ธํ ๊ฒ ์ฒ๋ผ position type ids๋ฅผ ์ง์ ์ ๋ ฅ์ ๋ฃ์ด์ค ์๋ ์์ง๋ง ์ด๋ฒ์ ์ง์ ๋ง๋ค์ด์ค๊ฒ์! (์ค์ ๋ก source code์์ position_ids๊ฐ None์ด๋ฉด ์๋์ฒ๋ผ ๋ง๋ค์ด์ค์)
import torch
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_ids # ๋ค ๊ฐ์ ํ ํฐ์ ๋ํ ์์น ์ ๋ณด
Word token embedding์ ๊ฒฝ์ฐ vocab์ ์๋งํผ vector๊ฐ ์ ์๋์ด ์์ด์ผ ํฉ๋๋ค.
ํ์ง๋ง Word position embedding์ ๊ฒฝ์ฐ tokenizing์ ๊ฒฐ๊ณผ๋ก ๋์จ ํ ํฐ์ ์๋ก ๋งคํ์ด ๋๊ธฐ ๋๋ฌธ์ ๋ฏธ๋ฆฌ max_length๋ฅผ ์ ํด๋ฌ์. KoGPT2์ ๊ฒฝ์ฐ์ 1,024๋ค์!
์์ ๊ฐ์ ์ด์ ๋ก wte์ wpe์ matrix shape์ ๋ค๋ฆ ๋๋ค!
- GPT2๋ absolute position embedding์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ด์์!
- Transformer์ SInusoidal encoding์ ์ฌ์ฉํ๋ฉด extrapolate๋ฅผ ํ ์ ์๊ธฐ ๋๋ฌธ์ ์ ๋ฐ ์์น ๊ณ ์ ๋ฌธ์ ๋ ์๊ธฐ์ง ์๊ฒ ์ฃ !
(
model.transformer.wte, # vocab_size X hidden_dim
model.transformer.wpe, # max_position_length X hidden_dim
)
inputs_embeds = model.transformer.wte(input_ids)
position_embeds = model.transformer.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
hidden_states.shape # (batch_size, sequence length, hidden_dim)
print(f"n_layers: {len(model.transformer.h)}")
for i, block in enumerate(model.transformer.h):
outputs = block(hidden_states)
hidden_states = outputs[0]
hidden_states = model.transformer.ln_f(hidden_states) # final layer norm
hidden_states.shape
lm_logits = model.lm_head(hidden_states)
lm_logits.shape # (batch_size, sequence_length, vocab_size)
์ด๋ ๊ฒ ์ธ ๊ฐ์ง ๊ณผ์ ์ ๊ฑฐ์ณ์ ๋ชจ๋ธ์ Causal-LM, ์ด์ ๋จ์ด๋ค๋ก๋ถํฐ ๋ค์ ๋จ์ด๋ฅผ ์์ธกํ๋ Conditional next word distribution์ ํ์ตํ๊ฒ ๋ฉ๋๋ค. ์ถ๋ก ์์๋ ์ด์ ๋ถํฐ ์๊ฐํ decoding ๋ฐฉ๋ฒ๋ก ์ผ๋ก ๊ณ์ฐ๋ ํ๋ฅ ์ ์ด๋ป๊ฒ ์ฌ์ฉํ๋๋ ์ด๊ฒ์ด ๊ฐ๋ฆฌ๊ฒ ์ง์!
Greedy Search
Greedy search๋ ๋จ์ํ๊ฒ ๋ค์ ๋จ์ด๋ก ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๊ฐ์ง๋ ๋จ์ด๋ฅผ ์ ํํฉ๋๋ค. ์์์ผ๋ก ์ด๋ฅผ ๋ค์ ์ฐ๋ฉด,
$$w_t=\underset{w}{\mathrm{argmax}}{P(w|w_{1:t-1})}$$
- $W_0$๋ ์๋ต๋ ๊ฒ ๊ฐ์ต๋๋ค.
์ด๋ฅผ ์ด๋ฏธ์ง๋ก ๊ทธ๋ ค๋ณด๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
The
๋ผ๋ ๋จ์ด๋ก๋ถํฐ ์์ํ์ฌ ์๊ณ ๋ฆฌ์ฆ์ ํ์์ ์ผ๋ก(greedily) ๋ค์์ผ๋ก ์ฌ ๋จ์ด๋ก ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๊ฐ์ง๋ nice
๋ฅผ ์ ํํฉ๋๋ค. ์ด๋ ๊ฒ ์ข
๋ฃ ์์ ๊น์ง ํ์์ ์ผ๋ก ์ ํํ๋ฉด ์์ ์์์์ ์ต์ข
์ ์ผ๋ก ์์ฑ๋ word sequence๋ (The
, nice
, woman
)์ด ๋ ๊ฒ์
๋๋ค.
- ํด๋น word sequence๊ฐ ๋์ฌ ํ๋ฅ ์ $0.5 \times 0.4 = 0.2$๋ก ๊ฝค๋ ๋์ต๋๋ค.
๊ตฌํ ์์ธ๋ ์ ๊ฐ ๋ฏ์ด๋ณด๋ฉฐ ์์๋ธ ์ ๋ณด๋ฅผ ๋ค๋ฃจ๋ Chapter 2์์ ๋ค๋ฃจ๊ณ huggingface์์ ์ด๋ป๊ฒ ์ฌ์ฉํ ์ ์๋์ง ์์๋ด ์๋ค.
input_ids = tokenizer.encode("๊ทผ์ก์ด ์ปค์ง๊ธฐ ์ํด์๋", return_tensors="pt")
# CLM์ผ๋ก ๋ฌธ์ฅ์ ์์ฑ (output length๊ฐ 128์ ๋๋ฌํ ๋ ๊น์ง)
greedy_output = model.generate(input_ids, max_length=128)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
์ค... ์ ์์ฑํด๋๊ตฐ์ ใ ใ . ํ์ง๋ง ์์ธํ ๋ณด๋ฉด ๊ท์น์ ์ธ ์ํ์ต๊ด์ด ์ค์ํ๋ค๊ณ ๋ด์ฉ์ ๋ฐ๋ณตํ๋ ๋ฌธ์ ๊ฐ ๋ณด์ด๋๊ตฐ์...!
์ด๋ ์ผ๋ฐ์ ์ผ๋ก natural language generation์์ ์ผ๋ฐ์ ์ธ ๋ฌธ์ ์ด๋ฉฐ greedy search, beam search์ ๊ฐ์ maximization ๊ธฐ๋ฒ์์ ํจ์ฌ ๋ ์ฌํ๊ฒ ๋ฐ์๋ฉ๋๋ค.
Greedy search์ ์ฃผ๋ ๋จ์ ์ ๋ฎ์ ํ๋ฅ ์ฌ์ด์ ์จ๊ฒจ์ง ๋์ ํ๋ฅ ์ ๋จ์ด๋ฅผ ๋์น๋ ๊ฒ
์
๋๋ค. ์์ ์์์์๋ (The
, dog
, has
)๋ฅผ ๋์ณค์ฃ . ์ด ๋ฌธ์ฅ์ ์ฌ์ค $0.4 \times 0.9 = 0.36$์ผ๋ก ์์ ๋ฌธ์ฅ๋ณด๋ค ์กฐ๊ธ ๋ ๊ฐ๋ฅ์ฑ์ด ์๋ ๋ฌธ์ฅ์
๋๋ค.
๊ณ ๋ง๊ฒ๋ beam search๊ฐ ์ ๋ฌธ์ ๋ฅผ ์กฐ๊ธ ๋์ด์ค๋๋ค!
- alleviate์ ๋๋ค. ํด๊ฒฐํด์ฃผ์ง๋ ์์ต๋๋ค...
Beam Search
Beam search๋ ๊ฐ time step์์ ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ์ด ๋์ ๊ฐ์ค์ num_beams
๋งํผ ์ ์งํ๊ณ ๊ฒฐ๊ตญ ์ ์ฒด ํ๋ฅ ์ด ๊ฐ์ฅ ๋์ ๊ฐ์ค(hypothesis)๋ฅผ ์ ํํ์ฌ ์จ๊ฒจ์ง ๋์ ํ๋ฅ ์ word sequence๋ฅผ ๋์น ์ํ์ ์ค์
๋๋ค.
์๋ ์์๋ num_beams
๊ฐ 2์ผ ๋ beam search์ ๋์ ๊ณผ์ ์
๋๋ค.
Time step 1์์ ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ์ด ๋์ ๊ฐ์ค์ (The
, nice
)๋ก ํ๋ฅ ์ด 0.5, ๊ทธ ๋ค์์ผ๋ก ๋์ ํ๋ฅ ์ ๋ณด์ด๋ ๊ฐ์ค์ (The
, dog
)๋ก ํ๋ฅ ์ด 0.4์
๋๋ค. greedy search์์๋ top-1๋ง ์ ํํ๊ธฐ ๋๋ฌธ์ ์๋ ๊ฐ์ค์ด ๋ฌด์๋์์ง๋ง beam search์์๋ num_beams
๋งํผ ๊ฐ์ค์ ์ ์งํ๊ธฐ ๋๋ฌธ์ ๋ ๋ฒ์งธ๋ก ๋์ ํ๋ฅ ์ ๊ฐ์ง๋ ๊ฐ์ค (The
, dog
)๋ฅผ ๊ธฐ๊ฐํ์ง ์๊ณ ์ ์งํฉ๋๋ค. (์ด๋ป๊ฒ ์ ์งํ๋์ง๋ Ch2)
Time step 2์์ (The
, dog
, has
)๋ 0.36์ ํ๋ฅ ์ ๊ฐ์ง๋ ๊ฐ์ค์ด๊ณ greedy search์ ๊ฒฐ๊ณผ์๋ (The
, nice
, woman
)์ 0.2์ ํ๋ฅ ์ ๊ฐ์ง๋ ๊ฐ์ค์
๋๋ค. ์ด๋์ ํ๋ฅ ์ด ๋ค์งํ์ฃ ?
Beam search๋ ํญ์ greedy search๋ณด๋ค ๋์ ํ๋ฅ ๋ก output sequence๋ฅผ ์ฐพ์์ค๋๋ค. ํ์ง๋ง ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ์ด ์๋ ์ถ๋ ฅ์ ์ฐพ๋ ๊ฒ์ ๋ณด์ฅ๋์ง ์์ฃ . (sub-optimal solution)
transformers์์์ ์์๋ฅผ ๋ด ์๋ค!
beam_output = model.generate(
input_ids,
max_length=128,
num_beams=5,
early_stopping=True
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
์... ํ์ง๋ง ์ถ๋ ฅ์ ์ฌ์ ํ ๋์ผํ word sequence์ ๋ฐ๋ณต์ด ํฌํจ๋๋ ๊ตฐ์...
์ด์ ๋ํ ๊ฐ๋จํ ํด๊ฒฐ์ฑ
์ Paulus ์ฐ๊ตฌ์ง์ด ๋์
ํ n-grams penalty
(a.k.a word sequences of n words)๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค.
- A Deep Reinforced Model for Abstractive Summarization, Paulus et al. (2017)
- OpenNMT: Open-Source Toolkit for Neural Machine Translation, Klein et al. (2017)
๊ฐ์ฅ ํํ n-grams penalty
๋ ์ด๋ฏธ ๋ณธ n-gram์ ์์ฑํ ์ ์๋ ๋ค์ ๋จ์ด์ ํ๋ฅ ์ ์๋์ผ๋ก 0์ผ๋ก ์ค์ ํ์ฌ n-gram์ด ๋ ๋ฒ ๋ค์ ๋ํ๋์ง ์๋๋ก ํ๋ ๊ฒ์
๋๋ค.
no_repeat_ngram_size
๋ฅผ 2๋ก ์ค์ ํ์ฌ ๋์ผํ n-gram์ด 2๋ฒ ์ด์ ๋ฐ๋ณต๋์ง ์๋๋ก ์์ ํด๋ณด์ฃ !
beam_output = model.generate(
input_ids,
max_length=128,
num_beams=5,
no_repeat_ngram_size=2,
early_stopping=True
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
๋ฐ๋ณต์ ๋ฌธ์ ๋ ํด๊ฒฐ๋์๊ตฐ์! (๋ค๋ง ๊ฐ๋ถ ํ๋์๋์ฐจ... ์ ๋ ๋ค๋ง ๊ทผ์ก์ด ์ปค์ง๋ ๋ฐฉ๋ฒ์ ์๊ณ ์ถ์์ต๋๋ค๋ง...)
๋ค๋ง patricks์ ๋ฐ๋ฅด๋ฉด n-gram penalty๋ ์ฃผ์ํด์ ์ฌ์ฉํด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด New York ์์ ๋ํด ์์ฑ๋ ๊ธฐ์ฌ๋ 2-gram penalty๋ฅผ ์ฌ์ฉํ๋ฉด ์ ์ฒด ํ ์คํธ์์ ๋์ ์ด๋ฆ์ด ํ ๋ฒ๋ง ๋ํ๋๊ธฐ ๋๋ฌธ์ ๋๋ค.
Beam search์ ๋ ๋ค๋ฅธ ์ค์ํ ๊ธฐ๋ฅ์ ์์ฑ ํ top beams๋ฅผ ๋น๊ตํ๊ณ ๋ชฉ์ ์ ๊ฐ์ฅ ์ ๋ง๋ generated beam์ ์ ํํ ์ ์๋ค๋ ์ ์ ๋๋ค.
๐ค transformers์์ num_return_sequences
๋งค๊ฐ๋ณ์๋ฅผ ์ธํ
ํ๋ฉด ์์ ์์
์ ์ํํ ์ ์์ต๋๋ค.
-
num_return_sequences
๋ ํญ์num_beams
๋ณด๋ค ์์์ผ ํฉ๋๋ค.
beam_outputs = model.generate(
input_ids,
max_length=128,
num_beams=5,
no_repeat_ngram_size=2,
num_return_sequences=5,
early_stopping=True
)
# now we have 3 output sequences
print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
decoded_text = tokenizer.decode(beam_output, skip_special_tokens=True)
print(f"{i}: {decoded_text}", end="\n\n")
์... ๋ฐํ๋ฐ์์ง๋ง ๊ฐ beam๋ค์ด ํฌ๊ฒ ๋ค๋ฅด์ง ์์ต๋๋ค.
Open-ended generation์์ ์ต๊ทผ์ beam search๊ฐ ์ต์ ์ด ์๋ ์ ์๋ค๋ ๋ช ๊ฐ์ง ์ด์ ๊ฐ ์ ๊ธฐ๋์์ต๋๋ค.
- Beam search๋ ๊ธฐ๊ณ ๋ฒ์ญ์ด๋ ์์ฝ๊ฐ์ด ์ํ๋ ์์ฑ์ ๊ธธ์ด๊ฐ ์ด๋ ์ ๋ ์์ธก ๊ฐ๋ฅํ ์์ ์์ ๋งค์ฐ ์ ๋์ํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ํ๋ ์ถ๋ ฅ์ ๊ธธ์ด๊ฐ ํฌ๊ฒ ๋ฌ๋ผ์ง ์ ์๋ open-ended generation์ ๊ฒฝ์ฐ(๋ํ ํน์ story ์์ฑ) ๊ทธ๋ ์ง ์์ต๋๋ค.
- ์์์ ํ์ธํ๋ฏ beam search๋
repetitive generation
์ ์ทจ์ฝํฉ๋๋ค. ์ด๋ n-gram ํน์ ๋ค๋ฅธ penalty๋ก ์ ์ ํ ์กฐ์ ํ๊ธฐ๊ฐ ์ด๋ ต์ต๋๋ค. - Holtzman์ ๋ฐ๋ฅด๋ฉด, High quality human language๋ ๋ค์ ๋จ์ด๊ฐ ๊ฐ์ฅ ๋๊ฒ ์ฌ ํ๋ฅ ๋ถํฌ๋ฅผ ๋ฐ๋ฅด์ง ์์ต๋๋ค. ์ฆ ์ธ๊ฐ์ ์์ฑ๋ ํ ์คํธ๊ฐ ์ฐ๋ฆฌ๋ฅผ ๋๋ผ๊ฒ ํ๊ณ (surprise) ์ง๋ฃจํ๊ฑฐ๋(boring) ์์ธกํ ์ ์๊ธฐ๋ฅผ(not to be predictable) ์ํฉ๋๋ค. ์ ์๋ BeamSearch๋ก ์์ฑ๋ text๊ฐ ๋ ๋๋๋ค๋ ๊ฒ์ ์๋ plot์ผ๋ก ๋ณด์ฌ์ค๋๋ค.
์, ์ง๋ฃจํ text๋ ๊ทธ๋ง ์์ฑํ๊ณ randomness๋ฅผ ๋์ ํฉ์๋ค :)
์ํ๋ง์ ์ฌ์ฉํ ์ธ์ด ์์ฑ์ ๋ ์ด์ ๊ฒฐ์ ์ ์ด์ง ์์ต๋๋ค.
- Maximization ๊ธฐ๋ฒ(greedy, beam)์ ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๊ฐ์ง๋ ๊ฐ์ค๋ง์ ํํ์ต๋๋ค.
๋จ์ด car
์ beam์ 3๋งํผ ๋๋ฆฌ์ง ์์ผ๋ฉด ์ด์ maximization ๊ธฐ๋ฒ์์๋ ์ด๋ ํ ๊ฒฝ์ฐ์๋ ์ ๋๋ก ์ฑํ๋์ง ์์ต๋๋ค. ํ์ง๋ง ์ ๋ง ๋ฎ์ ํ๋ฅ ๋ก ์กฐ๊ฑด๋ถ ํ๋ฅ ๋ถํฌ $P(w|`the`)$์์ ๋จ์ด car
๊ฐ ์ถ์ถ๋ ์ ์์ผ๋ฉฐ ๋ค์ ๋จ์ด์ธ drives
, is
, turns
๊ฐ ์กฐ๊ฑด๋ถ ํ๋ฅ ๋ถํฌ $P(w|`the`,`car)$์์ ์ถ์ถ๋ ๊ฒ ์
๋๋ค.
๐ค transformers์์ do_sample
์ต์
์ ํ์ฑํ์ํค๊ณ top-k sampling์ ๋นํ์ฑํ์์ผ์ ์๋ฅผ ๊ตฌํํ ์ ์์ต๋๋ค.
import random
import torch
import numpy as np
def set_seed(seed: int = 42):
"""Seed fixer (random, numpy, torch)
Args:
seed (:obj:`int`): The seed to set.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed()
sample_output = model.generate(
input_ids,
do_sample=True,
max_length=128,
top_k=0,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
์ด... ๋ด์ฉ์ ์ฐจ์นํ๊ณ ๋ฐ๋ณต ๋ฌธ์ ๋ ์๋ณด์ด๋ค์! ํ์ง๋ง ํํ์ด ์ด์ํฉ๋๋ค.
- ๊ต์ ์ด ์ด๋ ต๋ค๊ณ ํ์ฌ ํ๋ ๊ฒ์ด๋ค.
- ํนํ ๊ต์ ์ ์๊ฐ์ ํฌ์ํ๋ค
- ๋ณด๋ฉด ๊ต์ ์์ ์ ์์ฌ์ด ์์๊ฒ ๋๋ ์ฆ์ด ์๋ค
์ด๋ word sequence๋ฅผ samplingํ ๋ ์๊ธฐ๋ ํฐ ๋ฌธ์ ์ ๋๋ค. ๋ชจ๋ธ์ ์ข ์ข ์ผ๊ด์ฑ์์ด ํก์ค์์คํฉ๋๋ค.
์๋ฅผ ํด๊ฒฐํ ํธ๋ฆญ์ ๋ถํฌ $P(w|w_{1:t-1})$์ sharpํ๊ฒ ๋ง๋๋ ๊ฒ์ ๋๋ค.
- ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ๊ฐ์ง๋ ๋จ์ด์ likelihood๋ฅผ ๋์ด๊ณ
- ๊ฐ์ฅ ๋ฎ์ ํ๋ฅ ์ ๊ฐ์ง๋ ๋จ์ด์ likelihood๋ฅผ ๋ฎ์ถ๋ ๊ฒ
์ ํธ๋ฆญ์ softmax์ temperature
๋ผ๊ณ ๋ถ๋ฆฝ๋๋ค.
temperature๋ฅผ ์ ์ฉํ ์์์ ๋ํ ์๊ฐํ์ ๋๋ค.
temperature๋ฅผ ์ ์ฉํ๊ธฐ ์ ์๋ $P(w|`the`)$์์ car
๊ฐ ๋ฝํ ํ๋ฅ ์ด 0.1์ด์์ง๋ง ์ง๊ธ์ 0.02์
๋๋ค. ๋ฎ์์ง ๋งํผ ๋ฝํ๊ธฐ๋ ๋ ํ๋ค๊ฒ ์ฃ ?
sample_output = model.generate(
input_ids,
do_sample=True,
max_length=128,
top_k=0,
temperature=0.7,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Maximization์ ๊ฒฐ๊ณผ์ ์ ์ฌํ๋ฉด์ ๋ฐ๋ณต์ ์ํ๊ณ ๋ค๋ฅธ ๋ด์ฉ๊น์ง ์ถ๊ฐ๋์์ต๋๋ค! (๋ฌผ๋ก ๊ทผ์ก๊ณผ๋ ์์ง๋ ๊ด๋ จ์ด ์ ์ต๋๋ค... ๊ทธ๋๋ ์ผ๊ด์ฑ์ ๊ฐ์ ๋์๊ตฐ์.)
temperature๋ฅผ ์ ์ฉํ๋ฉด ๋ถํฌ๋ฅผ ๋ randomํ๊ฒ ๋ง๋ค ์ ์์ง๋ง 0์ผ๋ก ์ค์ ํ๋ฉด greedy decoding๊ณผ ๋์ผํด์ง๋๋ค. ๊ทธ๋ฌ๋ฉด ์ด์ ๊ณผ ๊ฐ์ ๋ฌธ์ ๋ฅผ ๋ค์ ๊ฒช๊ฒ ๋๊ฒ ์ง์.
Top-K Sampling
Fan ์ฐ๊ตฌ์ง์ ์์ฃผ ๊ฐ๋จํ์ง๋ง ๊ฐ๋ ฅํ sampling scheme์ธ Top-K๋ฅผ ์๊ฐํ์ต๋๋ค.
Top-K sampling์์ ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ์ด ๋์ K๊ฐ์ ๋ค์ ๋จ์ด๋ filtering๋๊ณ probability mass๋ K๊ฐ์ ๋ค์ ๋จ์ด์ ๋ํด์๋ง ์ฌ๋ถ๋ฐฐ๋ฉ๋๋ค. GPT2๊ฐ ์ด sampling scheme๋ฅผ ํํ๊ณ story generation์์ ์ฑ๊ณต์ ์ด์๋ ์์ธ ์ค ํ๋๋ก ํ๊ฐ๋ฉ๋๋ค.
- ์ ์ฒด ๋จ์ด ๋ถํฌ์์ ์ํ๋งํ๋ ๊ฒ์ด ์๋๋ผ ๊ณ ์ ๋ ์์ K๊ฐ์ ๋จ์ด์์ sampling ์ํ
Time step 1์์ Top-6๊ฐ๋ฅผ ์ ์ธํ ๋๋จธ์ง people
, big
, house
, cat
์ ์์ฑ ๋์์์ ์ ๊ฑฐํฉ๋๋ค. (Top-K๊ฐ๋ง filtering, Vocab์ pick-up)
Step 1์์๋ ์ ์ฒด์ 2/3, step 2์์๋ ๊ฑฐ์ ๋ชจ๋ probability mass๋ฅผ ํฌํจํฉ๋๋ค.
Time step 2์์ Top-6๊ฐ๋ฅผ ์ ์ธํ ๋๋จธ์ง not
, the
, small
, told
์ด์ํ ๋จ์ด๋ค์ ์ฑ๊ณต์ ์ผ๋ก ์ ์ธํ๊ณ ์ถ์ถํ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
sample_output = model.generate(
input_ids,
do_sample=True,
max_length=128,
top_k=50,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
์ ์ผ ๊ด์ฐฎ์ ๊ฒฐ๊ณผ์ธ ๊ฒ ๊ฐ์ต๋๋ค! ๊ฐ์ฅ ์ธ๊ฐ๊ฐ์ด ์์ฑ๋ ๊ฒ ๊ฐ๊ตฐ์. Top-K sampling์ ํ ๊ฐ์ง ๋ฌธ์ ๋ next word distribution $P(w|w_{1:t-1})$์์ filtering๋๋ ๋จ์ด์ ์๋ฅผ dynamicํ๊ฒ ์ ์ฉํ์ง ์๋ ๋ค๋ ์ ์ ๋๋ค.
- ๊ณ ์ ๋ K๋ฅผ ์ฌ์ฉํ๊ธฐ์ ๋ฌธ์
์ ๊ทธ๋ํ์์ ์ค๋ฅธ์ชฝ์ ๊ฒฝ์ฐ ๋งค์ฐ sharpํ ๋ถํฌ์์ sampling๋์ง๋ง ์ผ์ชฝ์ ๊ฒฝ์ฐ์๋ ๋ flatํ ๋ถํฌ์์ sampling๋๊ธฐ์ ๋ฌธ์ ๊ฐ ๋ ์ ์์ต๋๋ค.
Step 1์์ Top-K๋ people
, big
, house
, cat
์ ๊ฐ์ ๊ฐ๋ฅ์ฑ์๋ ํ๋ณด๊ตฐ๋ค์ ์ ์ธํ์ต๋๋ค. ๋ฐ๋๋ก Step 2์์๋ ๋จ์ด์ sample pool(In top-k)์ ๋ถ์ ํฉํ ๋จ์ด down
, a
๋ฅผ ํฌํจํฉ๋๋ค. ๋๋ฌธ์ sample pool์ ๊ณ ์ ๋ ํฌ๊ธฐ K๋ก ์ ํํ๋ ๊ฒ์ ๋ชจ๋ธ์ด sharp distribution์ ๋ํด ํก์ค์์ค(gibberish)ํ ์ํ์ด ์๊ณ flat distribution์ ๋ํด ์ฐฝ์์ฑ(creativity)์ด ์ ํ๋ ์ ์์ต๋๋ค.
์ ์ง๊ด์ด Ari Holtzman ์ฐ๊ตฌ์ง๋ค์ด ์ ์ํ Top-p ํน์ nucleus sampling์ผ๋ก ์ด์ด์ง๋๋ค.
Top-p (nucleus) Sampling
๊ฐ์ฅ ๋์ K๊ฐ์ ๋จ์ด๋ฅผ ์ ํํ๋ ๋์ Top-P sampling์ ๋์ ํ๋ฅ ์ด ํ๋ฅ p๋ฅผ ์ด๊ณผํ๋ ๊ฐ๋ฅํ ๊ฐ์ฅ ์์ ๋จ์ด ์งํฉ์์ ์ ํํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ probability mass๋ ์ด ๋จ์ด set ์ฌ์ด์ ์ฌ๋ถ๋ฐฐ๋ฉ๋๋ค. ์ด๋ฐ ์์ผ๋ก ๋จ์ด ์งํฉ์ ํฌ๊ธฐ(a.k.a the number of words in the set)์ ๋ค์ ๋จ์ด์ ํ๋ฅ ๋ถํฌ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ์ฆ๊ฐํ๊ฑฐ๋ ๊ฐ์ํ ์ ์์ต๋๋ค.
์๋ฅผ ์๊ฐํํด๋ด ์๋ค!
$p=0.92$๋ก ์ค์ ํ๊ฒ ์ต๋๋ค. Top-p sampling์ probability mass์ 92%๋ฅผ ์ด๊ณผํ๋ ๋จ์ด์ minimum number๋ฅผ ๊ณ์ฐํฉ๋๋ค. ์ด๋ฅผํ
๋ฉด ์์์ cat
์ ์ ์ธํ ๋จ์ด์ prob mass์ ํฉ์ 0.94๋ก ์ค์ ํ p๋ณด๋ค ์ปค์ง๊ฒ ๋ฉ๋๋ค. ์ฆ time step 1์์๋ 9๊ฐ์ ๋จ์ด๋ฅผ ๊ณ ๋ฅด๊ณ time step 2์์๋ drives
, is
, turns
๋ง์ผ๋ก๋ 97%์
๋๋ค. ๋๋ฌธ์ 3๊ฐ์ ๋จ์ด๋ก ๊ณ ์ ํ sampling์ ์ํํฉ๋๋ค. Top-K์์ ๊ณ ์ ์ ์ธ K๋ก samplingํ ๊ฒ๊ณผ ๋ค๋ฅด๊ฒ Top-p์์๋ next word distribution์ ๋ฐ๋ผ dynamicํ๊ฒ sampling pool์ ๊ฒฐ์ ํ ์ ์์ต๋๋ค.
sample_output = model.generate(
input_ids,
do_sample=True,
max_length=128,
top_p=0.92,
top_k=0,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
์ข์ต๋๋ค! ๋งจ ์ฒ์ samplingํ์ ๊ฒฐ๊ณผ๋ณด๋ค๋ ํจ์ฌ ๋ ์ฌ๋๋ค์ ์ก์ต๋๋ค. (๋ด์ฉ์...)
์ด๋ก ์ ์ผ๋ก Top-p๋ Top-k๋ณด๋ค ๋ ์ฐ์ํด๋ณด์ด์ง๋ง ์ค์ ๋ก๋ ๋ ๋ฐฉ๋ฒ ๋ชจ๋ ์ ๋์ํ๊ณ Top-p๋ Top-k์ ํจ๊ป ์ฌ์ฉํ ์ ์์ต๋๋ค. Top-K๋ ๋งค์ฐ ๋ฎ์ ์์์ ๋จ์ด๋ฅผ ํผํ๋ฉด์ ์ผ๋ถ ๋์ ์ ํ์ ํ์ฉํ ์ ์์ต๋๋ค.
๋ง์ง๋ง์ผ๋ก ๋
๋ฆฝ์ ์ผ๋ก ์ํ๋ง๋ ์ฌ๋ฌ ์ถ๋ ฅ์ ์ป๊ธฐ ์ํด ๋งค๊ฐ๋ณ์ num_return_sequences
๋ฅผ 1๋ณด๋ค ํฌ๊ฒ ๋ค์ ์ค์ ํ ์ ์์ต๋๋ค.
sample_outputs = model.generate(
input_ids,
do_sample=True,
max_length=128,
top_p=0.95,
top_k=50,
num_return_sequences=3,
)
print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
decoded_text = tokenizer.decode(sample_output, skip_special_tokens=True)
print(f"{i}: {decoded_text}", end="\n\n")
Conclusion
- top-p, top-k sampling์ open-ended language generation์์ ๊ธฐ์กด์ greedy-and beam search๋ณด๋ค ๋ ์ ์ฐฝํ text๋ฅผ ์์ฑํ๋ ๊ฒ์ผ๋ก ๋ณด์
- ์ต๊ทผ์ greedy ๋ฐ beam search์ ๋ช ๋ฐฑํ ๊ฒฝํจ(์ฃผ๋ก ๋ฐ๋ณต์ ์ธ word sequence ์์ฑ)์ด decoding methodology๋ณด๋ค๋ model(ํนํ ๋ชจ๋ธ์ด ํ๋ จ๋๋ ๋ฐฉ์)์ ์ํด ๋ฐ์ํ๋ค๋ ์ฆ๊ฑฐ๊ฐ ๋ ๋ง์ด ์์.
- ๋ํ top-k ๋ฐ top-p sampling๋ repetitive word sequence ์์ฑ์์ ์์ ๋กญ์ง ๋ชปํ๋ ๊ฒ์ผ๋ก ๋ณด์
- Welleck์ 2019 ์ฐ๊ตฌ์ ์ํ๋ฉด ์ ์๋ ์ฌ๋์ ํ๊ฐ์ ๋ฐ๋ฅด๋ฉด ๋ชจ๋ธ์ ํ๋ จ ๋ชฉํ๋ฅผ ์กฐ์ ํ ๋ Beam search๊ฐ Top-p sampling๋ณด๋ค ๋ ์ ์ฐฝํ text๋ฅผ ์์ฑํ ์ ์์์ ๋ณด์
- Open-ended language generation์ ๋น ๋ฅด๊ฒ ๋ฐ์ ํ๋ ๋ถ์ผ์ด๋ฉฐ ์ฌ๊ธฐ์ ๋ชจ๋ ๊ฒฝ์ฐ์ ์ ์ฉํ ์ ์๋ ๋ฐฉ๋ฒ์ด ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง์. ๋๋ฌธ์ ํน์ ์ฌ์ฉ ์ฌ๋ก์ ๊ฐ์ฅ ์ ํฉํ ๋ฐฉ๋ฒ์ด ๋ฌด์์ธ์ง๋ฅผ ํ์ธํด์ผ ํ๋ค.
Appendix
์์์ ์ธ๊ธํ์ง ์์ ์์ฑ ๋ฉ์๋์ ๋ํ ๋ช ๊ฐ์ง ์ถ๊ฐ ๋งค๊ฐ๋ณ์ ์๊ฐ
-
min_length
: min_lenght์ ๋๋ฌํ๊ธฐ ์ ์ ๋ชจ๋ธ์ด EOS token์ ์์ฑํ์ง ์๋๋ก ๊ฐ์ ํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์- ์์ฝ์์ ๋งค์ฐ ์์ฃผ ์ฌ์ฉ๋์ง๋ง ์ฌ์ฉ์๊ฐ ๋ ๊ธด ์ถ๋ ฅ์ ์ํ ๊ฒฝ์ฐ ์ผ๋ฐ์ ์ผ๋ก ์ ์ฉํ ์ ์์
-
repeat_penalty
: ์ด๋ฏธ ์์ฑ๋์๊ฑฐ๋ context์ ์ํ๋ ๋จ์ด์ penalty๋ฅผ ์ ์ฉํ๋๋ฐ ์ฌ์ฉ. Keskar et al., (2019)์ ์ํด ์ฒ์์ผ๋ก ์๊ฐ๋์์ผ๋ฉฐ Welleck et al., (2019)์ training objective๋ก๋ ์ฌ์ฉ๋จ. ๋ฐ๋ณต์ ๋ฐฉ์งํ๋๋ฐ ๋งค์ฐ ํจ๊ณผ์ ์ผ ์ ์์ง๋ง ๋ค์ํ ๋ชจ๋ธ ๋ฐ ์ฌ์ฉ ์ฌ๋ก์ ๋งค์ฐ ๋ฏผ๊ฐํ ๊ฒ์ผ๋ก ๋ณด์. ํด๋น ๋์ค์ปค์ ์ฐธ๊ณ . -
attention_mask
: padded token์ maskํ๋๋ฐ ์ฌ์ฉ -
pad_token_id
,bos_token_id
,eos_token_id
: ๋ชจ๋ธ์ ๊ธฐ๋ณธ์ ์ผ๋ก ํด๋น ํ ํฐ์ด ์๋ ๊ฒฝ์ฐ ์ฌ์ฉ์๋ ๋ค๋ฅธ token id๋ฅผ ์๋์ผ๋ก ์ ํํ์ฌ ๋ํ๋ผ ์ ์์.