Introduction

๋ณธ ํฌ์ŠคํŒ…์—์„œ๋Š” ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•˜์—ฌ ์•Œ์•„๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

ํ•ด๋‹น ํฌ์ŠคํŒ…์€ patrick von platen๋‹˜์˜ ํฌ์ŠคํŒ…์„ ๋ฒˆ์—ญํ•œ ๋‚ด์šฉ๊ณผ ์ œ๊ฐ€ ์ง์ ‘ ๐Ÿค— transformers๋ฅผ ๋œฏ์–ด๋ณธ ๋‚ด์šฉ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ž‘์„ฑํ•˜์˜€์Šต๋‹ˆ๋‹ค.

Using different decoding methods for language generation with Transformers

Introduction

์ตœ๊ทผ ๋ช‡ ๋…„ ๋™์•ˆ OpenAI์˜ GPT-2 ๋ชจ๋ธ๊ณผ ๊ฐ™์ด ์ˆ˜๋ฐฑ๋งŒ ๊ฐœ์˜ ์›น ํŽ˜์ด์ง€์—์„œ ํ›ˆ๋ จ๋œ large-scale transformer ๊ธฐ๋ฐ˜ ์–ธ์–ด ๋ชจ๋ธ์˜ ๋“ฑ์žฅ์œผ๋กœ open-ended language generation์— ๋Œ€ํ•œ ๊ด€์‹ฌ์ด ๋†’์•„์กŒ์Šต๋‹ˆ๋‹ค. Conditioned open-ended language generation์˜ ๊ฒฐ๊ณผ๋Š” ๊ต‰์žฅํžˆ ์ธ์ƒ์ ์ž…๋‹ˆ๋‹ค.

2017๋…„์— transformer๊ฐ€ ๋“ฑ์žฅํ•œ ์ด๋ž˜๋กœ ์•„ํ‚คํ…์ณ๋ฅผ ์ˆ˜์ •ํ•˜๋Š” ๋ฐฉ๋ฒ• ๋ฐ ๋ฐฉ๋Œ€ํ•œ unsupervised ํ•™์Šต ๋ฐ์ดํ„ฐ(self-supervised๋ฅผ ์œ„ํ•œ)๋“ค์ด ์žˆ์ง€๋งŒ ๋” ๋‚˜์€ ๋””์ฝ”๋”ฉ ๋ฐฉ๋ฒ•(better decoding methods) ๋˜ํ•œ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ–ˆ์Šต๋‹ˆ๋‹ค.

์ด ๋ธ”๋กœ๊ทธ ํฌ์ŠคํŠธ๋Š” ๋‹ค์–‘ํ•œ ๋””์ฝ”๋”ฉ ์ „๋žต์— ๋Œ€ํ•œ ๊ฐ„๋žตํ•œ ๊ฐœ์š”๋ฅผ ์ œ๊ณตํ•˜๊ณ  ๋” ์ค‘์š”ํ•œ ๊ฒƒ์€ ์œ ๋ช…ํ•œ ๐Ÿค— transformers ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์•„์ฃผ ์ ์€ ๋…ธ๋ ฅ์œผ๋กœ ์ด๋ฅผ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค!

  • jinmang2: In this post you will also see how the decoding methods are actually implemented inside 🤗 transformers.

๋ณธ ๊ฒŒ์‹œ๊ธ€์—์„œ ๋‹ค๋ฃจ๋Š” ๋ชจ๋“  ๊ธฐ๋Šฅ๋“ค(functionalities)์€ ๋ชจ๋‘ auto-regressive language generation(๋” ์ž์„ธํ•œ ๋‚ด์šฉ์€ Jay Alammar๋‹˜์˜ ํฌ์ŠคํŒ…์„ ํ™•์ธํ•ด์ฃผ์„ธ์š”)์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. auto-regressive๋ž€ ๊ฐ„๋‹จํžˆ ๋งํ•ด ์•„๋ž˜ ๊ฐ€์ •์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋‘” ๋ฐฉ๋ฒ•๋ก ์ž…๋‹ˆ๋‹ค.

Assumption
==========
The probability distribution of a word sequence can be decomposed into
the product of conditional next word distributions.

That is, a word sequence (you can think of it as a sentence) is assumed to be decomposable into a product of conditional next-word distributions.

์ˆ˜์‹์œผ๋กœ ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

$$P(w_{1:T}|W_0)=\prod_{t=1}^{T}P(w_t|w_{1:t-1},W_0),\;\text{with}\, w_{1:0}=\emptyset$$

  • $W_0$์€ initial context word sequence
  • $T$๋Š” ๋ฌธ์žฅ์˜ ๊ธธ์ด์ž…๋‹ˆ๋‹ค.
    • patrick๋‹˜ ํฌ์ŠคํŒ… ๋ณธ๋ฌธ์—์„œ๋Š” ์ƒ์„ฑ๋œ ๋ฌธ์žฅ์˜ ๊ธธ์ด๋ฅผ ์˜๋ฏธํ•˜์‹œ๋Š” ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. ์ƒ์„ฑ์€ ๋‹จ์–ด(ํ˜น์€ ํ† ํฐ) ๋‹จ์œ„๋กœ ๊ธธ์ด๊ฐ€ 1์”ฉ ๋Š˜์–ด๋‚˜๊ธฐ ๋•Œ๋ฌธ์— on-the-fly๋ผ๋Š” ํ‘œํ˜„์„ ์‚ฌ์šฉํ•˜์‹  ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.
  • Timestep $t$๊ฐ€ $T$๊ฐ€ ๋˜๋ฉด? ๋ฌธ์žฅ์„ ์ „๋ถ€ ๋ถ„ํ•ดํ–ˆ๋‹ค๋Š” ๋œป์ด๊ฒ ์ฃ ?(conditional next word dists.๋กœ) ์ด ๋•Œ๋Š” $P(w_t|w_{1:t-1},W_0)$์—์„œ EOS ํ† ํฐ์ด ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.
    • EOS๋Š” End of Sentence(ํ˜น์€ Sequence)์˜ ์•ฝ์ž๋กœ ๋ฌธ์žฅ์˜ ๋์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.

์ด์ œ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๋„ค ๊ฐ€์ง€ decoding ๋ฐฉ๋ฒ•์— ๋Œ€ํ•˜์—ฌ ์†Œ๊ฐœํ•ด๋“œ๋ฆฌ๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

๋ธ”๋กœ๊ทธ ํฌ์ŠคํŒ…์˜ ์˜ˆ์‹œ๋ฅผ ๋ณด์…”๋„ ์ข‹๊ณ  ์•„๋ฌด๋ž˜๋„ ํ•œ๊ตญ์–ด๋กœ ๋ฒˆ์—ญ๋œ ํฌ์ŠคํŒ…์ด๊ธฐ ๋•Œ๋ฌธ์— ํ•œ๊ตญ์–ด ์–ธ์–ด ๋ชจ๋ธ๋กœ ์˜ˆ์‹œ๋ฅผ ๋“ค์–ด๋“œ๋ฆฌ๋Š” ๊ฒƒ์ด ๋ณด๊ธฐ ์ข‹์„ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค. ์ด์— ๋Œ€ํ•œ ์„ธํŒ…์„ ์ˆ˜ํ–‰ํ•˜์ฃ .

  • tensorflow๋Š” ์‚ฌ์šฉํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
!pip install -q git+https://github.com/huggingface/transformers.git

KoGPT2

์˜ˆ์‹œ์—์„œ ์‚ฌ์šฉํ•  ๋ชจ๋ธ์€ SKT์—์„œ ๊ฐœ๋ฐœํ•œ KoGPT2์ด๋ฉฐ ์ž์„ธํ•œ ์„ค๋ช…์€ ์•„๋ž˜ ๋งํฌ๋ฅผ ์ฐธ๊ณ ํ•ด์ฃผ์„ธ์š”.

๊ฐ„๋žตํ•œ ์†Œ๊ฐœ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

  • Vocab size: 51,200
  • Emojis, emoticons, etc. were added to the vocabulary to improve the model's handling of such tokens
  • 100 unused tokens are reserved so they can be freely defined for whatever task you need
  • The metaspace character is ▁
Model           # of params  Type     # of layers  # of heads  ffn_dim  hidden_dims
kogpt2-base-v2  125M         Decoder  12           12          3072     768
  • ์‚ฌ์šฉํ•œ ๋ฐ์ดํ„ฐ๋Š” ํ•œ๊ตญ์–ด ์œ„ํ‚ค ๋ฐฑ๊ณผ, ๋‰ด์Šค, ๋ชจ๋‘์˜ ๋ง๋ญ‰์น˜ v1.0, ์ฒญ์™€๋Œ€ ๊ตญ๋ฏผ์ฒญ์› ๋“ฑ ๋‹ค์–‘ํ•œ ๋ฐ์ดํ„ฐ
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path_or_name = "skt/kogpt2-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name, force_download=True)
model = AutoModelForCausalLM.from_pretrained(model_path_or_name, force_download=True)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
tokenizer # Rust๋กœ ๊ตฌํ˜„๋œ `Fast`ํ•œ tokenizer
PreTrainedTokenizerFast(name_or_path='skt/kogpt2-base-v2', vocab_size=51200, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})
tokenizer.tokenize("๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š”")
['โ–๊ทผ์œก์ด', 'โ–์ปค', '์ง€๊ธฐ', 'โ–์œ„ํ•ด์„œ๋Š”']
model.__class__.__name__ # AutoModelForCausalLM resolves to GPT2's CLM class
'GPT2LMHeadModel'
num_of_parameters = sum(p.numel() for n, p in model.named_parameters())
print(f"{num_of_parameters}") # 125M
125164032

Calculate word probability

Q) ๋‹จ์–ด๋ณ„ ํ™•๋ฅ ์€ ์–ด๋–ป๊ฒŒ ๊ตฌํ•˜๋‚˜์š”?

A) ์œ„์—์„œ GPT2ForLMHeadModel์„ Causal-LM์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์œ„ํ•œ ๋ชจ๋ธ๋กœ ๋ถˆ๋ €์ฃ ? ํ•ด๋‹น ๋ชจ๋ธ์ด left-to-right์œผ๋กœ contidional next word distribution์„ ๋ชจ๋ธ๋งํ•ด์ค๋‹ˆ๋‹ค!

An LMHeadModel (a model for causal LM) is roughly divided into three parts.

(1) word token/position embedding

  • ์ธ์ฝ”๋”ฉ๋œ word sequence์— ๋Œ€ํ•ด embedding value๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
  • ๋‹จ์–ด(ํ˜น์€ ํ† ํฐ)์˜ ๋œป์„ word token embedding์œผ๋กœ,
  • ๋‹จ์–ด(ํ˜น์€ ํ† ํฐ)์˜ ์œ„์น˜๋ฅผ word position embedding์œผ๋กœ ๋ฒกํ„ฐํ™”ํ•ด์ค๋‹ˆ๋‹ค.
    • position embedding์˜ ๊ฒฝ์šฐ layer์— ํ•ด๋‹น ์ •๋ณด๋ฅผ ๋„ฃ์–ด์ฃผ๋Š” ๊ฒฝ์šฐ๋„ ์žˆ์ง€๋งŒ (relative position embedding) ํ•ด๋‹น ํฌ์ŠคํŒ…์˜ ๋ฒ”์ฃผ๋ฅผ ๋„˜์–ด์„œ๊ธฐ ๋•Œ๋ฌธ์— ํ–ฅํ›„ ์†Œ๊ฐœํ•˜๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

Note that the code below is excerpted from the GPT2 script! How wte and wpe are obtained varies model by model!

word_sequence = "๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š”"
inputs = tokenizer(word_sequence, return_tensors="pt")
input_ids = inputs["input_ids"]
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids
tensor([[33245, 10114, 12748, 11357]])

์‹ค์ œ๋กœ ๋ชจ๋ธ ์ธํ’‹์— ์–ด๋–ป๊ฒŒ ๋“ค์–ด๊ฐ€๋‚˜ ํ™•์ธํ•ด๋ณด์ฃ .

import inspect

# input_ids, attention_mask, token_type_ids, position_ids are the important ones
# forward and __call__ are related because the model inherits `torch.nn.Module`
# that's a topic for another time!
inspect.signature(model.transformer.forward).parameters
mappingproxy({'input_ids': <Parameter "input_ids=None">,
              'past_key_values': <Parameter "past_key_values=None">,
              'attention_mask': <Parameter "attention_mask=None">,
              'token_type_ids': <Parameter "token_type_ids=None">,
              'position_ids': <Parameter "position_ids=None">,
              'head_mask': <Parameter "head_mask=None">,
              'inputs_embeds': <Parameter "inputs_embeds=None">,
              'encoder_hidden_states': <Parameter "encoder_hidden_states=None">,
              'encoder_attention_mask': <Parameter "encoder_attention_mask=None">,
              'use_cache': <Parameter "use_cache=None">,
              'output_attentions': <Parameter "output_attentions=None">,
              'output_hidden_states': <Parameter "output_hidden_states=None">,
              'return_dict': <Parameter "return_dict=None">})

์œ„์—์„œ ํ™•์ธํ•œ ๊ฒƒ ์ฒ˜๋Ÿผ position type ids๋ฅผ ์ง์ ‘ ์ž…๋ ฅ์— ๋„ฃ์–ด์ค„ ์ˆ˜๋„ ์žˆ์ง€๋งŒ ์ด๋ฒˆ์—” ์ง์ ‘ ๋งŒ๋“ค์–ด์ค„๊ฒŒ์š”! (์‹ค์ œ๋กœ source code์—์„œ position_ids๊ฐ€ None์ด๋ฉด ์•„๋ž˜์ฒ˜๋Ÿผ ๋งŒ๋“ค์–ด์ค˜์š”)

import torch

position_ids = torch.arange(0, input_shape[-1], dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_ids # position information for the four tokens
tensor([[0, 1, 2, 3]])

Word token embedding์˜ ๊ฒฝ์šฐ vocab์˜ ์ˆ˜๋งŒํผ vector๊ฐ€ ์ •์˜๋˜์–ด ์žˆ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

ํ•˜์ง€๋งŒ Word position embedding์˜ ๊ฒฝ์šฐ tokenizing์˜ ๊ฒฐ๊ณผ๋กœ ๋‚˜์˜จ ํ† ํฐ์˜ ์ˆ˜๋กœ ๋งคํ•‘์ด ๋˜๊ธฐ ๋•Œ๋ฌธ์— ๋ฏธ๋ฆฌ max_length๋ฅผ ์ •ํ•ด๋‘ฌ์š”. KoGPT2์˜ ๊ฒฝ์šฐ์—” 1,024๋„ค์š”!

์œ„์™€ ๊ฐ™์€ ์ด์œ ๋กœ wte์™€ wpe์˜ matrix shape์€ ๋‹ค๋ฆ…๋‹ˆ๋‹ค!

  • GPT2๋Š” absolute position embedding์„ ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์ด์—์š”!
  • Transformer์˜ SInusoidal encoding์„ ์‚ฌ์šฉํ•˜๋ฉด extrapolate๋ฅผ ํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ์ €๋Ÿฐ ์œ„์น˜ ๊ณ ์ • ๋ฌธ์ œ๋Š” ์ƒ๊ธฐ์ง€ ์•Š๊ฒ ์ฃ !
(
    model.transformer.wte, # vocab_size X hidden_dim
    model.transformer.wpe, # max_position_length X hidden_dim
)
(Embedding(51200, 768), Embedding(1024, 768))
inputs_embeds = model.transformer.wte(input_ids)
position_embeds = model.transformer.wpe(position_ids)

hidden_states = inputs_embeds + position_embeds
hidden_states.shape # (batch_size, sequence length, hidden_dim)
torch.Size([1, 4, 768])

(2) Transformer Layers

  • Self-Attention
  • Feed-Forward Network
  • ๊ธฐํƒ€ ๋ชจ๋“ˆ ๋“ฑ
print(f"n_layers: {len(model.transformer.h)}")
for i, block in enumerate(model.transformer.h):
    outputs = block(hidden_states)
    hidden_states = outputs[0]
hidden_states = model.transformer.ln_f(hidden_states) # final layer norm
n_layers: 12
hidden_states.shape
torch.Size([1, 4, 768])

(3) Language Model Head

  • Transformer layer๋“ค์„ ํ†ต๊ณผํ•˜์—ฌ ๋‚˜์˜จ hidden_states๋ฅผ ๊ฐ ํ† ํฐ ๋ณ„ ํ™•๋ฅ ๋กœ ๋งคํ•‘
lm_logits = model.lm_head(hidden_states)
lm_logits.shape # (batch_size, sequence_length, vocab_size)
torch.Size([1, 4, 51200])

์ด๋ ‡๊ฒŒ ์„ธ ๊ฐ€์ง€ ๊ณผ์ •์„ ๊ฑฐ์ณ์„œ ๋ชจ๋ธ์€ Causal-LM, ์ด์ „ ๋‹จ์–ด๋“ค๋กœ๋ถ€ํ„ฐ ๋‹ค์Œ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ•˜๋Š” Conditional next word distribution์„ ํ•™์Šตํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์ถ”๋ก ์—์„œ๋Š” ์ด์ œ๋ถ€ํ„ฐ ์†Œ๊ฐœํ•  decoding ๋ฐฉ๋ฒ•๋ก ์œผ๋กœ ๊ณ„์‚ฐ๋œ ํ™•๋ฅ ์„ ์–ด๋–ป๊ฒŒ ์‚ฌ์šฉํ•˜๋Š๋ƒ ์ด๊ฒƒ์ด ๊ฐˆ๋ฆฌ๊ฒ ์ง€์š”!

Maximization

Greedy search๋Š” ๋‹จ์ˆœํ•˜๊ฒŒ ๋‹ค์Œ ๋‹จ์–ด๋กœ ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” ๋‹จ์–ด๋ฅผ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค. ์ˆ˜์‹์œผ๋กœ ์ด๋ฅผ ๋‹ค์‹œ ์“ฐ๋ฉด,

$$w_t=\underset{w}{\mathrm{argmax}}{P(w|w_{1:t-1})}$$

  • $W_0$๋Š” ์ƒ๋žต๋œ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.

์ด๋ฅผ ์ด๋ฏธ์ง€๋กœ ๊ทธ๋ ค๋ณด๋ฉด ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

[Figure: greedy search over a toy tree starting from "The"]

The ๋ผ๋Š” ๋‹จ์–ด๋กœ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•˜์—ฌ ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ ํƒ์š•์ ์œผ๋กœ(greedily) ๋‹ค์Œ์œผ๋กœ ์˜ฌ ๋‹จ์–ด๋กœ ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” nice๋ฅผ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ์ข…๋ฃŒ ์‹œ์ ๊นŒ์ง€ ํƒ์š•์ ์œผ๋กœ ์„ ํƒํ•˜๋ฉด ์œ„์˜ ์˜ˆ์‹œ์—์„œ ์ตœ์ข…์ ์œผ๋กœ ์ƒ์„ฑ๋œ word sequence๋Š” (The, nice, woman)์ด ๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

  • ํ•ด๋‹น word sequence๊ฐ€ ๋‚˜์˜ฌ ํ™•๋ฅ ์€ $0.5 \times 0.4 = 0.2$๋กœ ๊ฝค๋‚˜ ๋†’์Šต๋‹ˆ๋‹ค.

๊ตฌํ˜„ ์ƒ์„ธ๋Š” ์ œ๊ฐ€ ๋œฏ์–ด๋ณด๋ฉฐ ์•Œ์•„๋‚ธ ์ •๋ณด๋ฅผ ๋‹ค๋ฃจ๋Š” Chapter 2์—์„œ ๋‹ค๋ฃจ๊ณ  huggingface์—์„œ ์–ด๋–ป๊ฒŒ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š”์ง€ ์•Œ์•„๋ด…์‹œ๋‹ค.

input_ids = tokenizer.encode("๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š”", return_tensors="pt")

# CLM์œผ๋กœ ๋ฌธ์žฅ์„ ์ƒ์„ฑ (output length๊ฐ€ 128์— ๋„๋‹ฌํ•  ๋•Œ ๊นŒ์ง€)
greedy_output = model.generate(input_ids, max_length=128)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋ฌด์—‡๋ณด๋‹ค ๊ทœ์น™์ ์ธ ์ƒํ™œ์Šต๊ด€์ด ์ค‘์š”ํ•˜๋‹ค.
ํŠนํžˆ, ์•„์นจ์‹์‚ฌ๋Š” ๋‹จ๋ฐฑ์งˆ๊ณผ ๋น„ํƒ€๋ฏผ, ๋ฌด๊ธฐ์งˆ ๋“ฑ ์˜์–‘์†Œ๊ฐ€ ํ’๋ถ€ํ•œ ์Œ์‹์„ ๊ณจ๊ณ ๋ฃจ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๋˜ํ•œ ๊ทœ์น™์ ์ธ ์šด๋™์€ ๊ทผ์œก์„ ๊ฐ•ํ™”์‹œ์ผœ์ฃผ๋Š” ํšจ๊ณผ๊ฐ€ ์žˆ๋‹ค.
ํŠนํžˆ, ์•„์นจ์‹์‚ฌ๋Š” ๋‹จ๋ฐฑ์งˆ๊ณผ ๋น„ํƒ€๋ฏผ, ๋ฌด๊ธฐ์งˆ ๋“ฑ ์˜์–‘์†Œ๊ฐ€ ํ’๋ถ€ํ•œ ์Œ์‹์„ ๊ณจ๊ณ ๋ฃจ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๋˜ํ•œ ๊ทœ์น™์ ์ธ ์šด๋™์€ ๊ทผ์œก์„ ๊ฐ•ํ™”์‹œ์ผœ์ฃผ๋Š” ํšจ๊ณผ๊ฐ€ ์žˆ๋‹ค.
ํŠนํžˆ, ์•„์นจ์‹์‚ฌ๋Š” ๋‹จ๋ฐฑ์งˆ๊ณผ ๋น„ํƒ€๋ฏผ, ๋ฌด๊ธฐ์งˆ ๋“ฑ ์˜์–‘์†Œ๊ฐ€ ํ’๋ถ€ํ•œ ์Œ์‹์„ ๊ณจ๊ณ ๋ฃจ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๋˜ํ•œ ๊ทœ์น™์ ์ธ ์šด๋™์€ ๊ทผ์œก์„ ๊ฐ•ํ™”์‹œ์ผœ์ฃผ๋Š” ํšจ๊ณผ๊ฐ€ ์žˆ๋‹ค.
๊ทผ์œก์„ ๊ฐ•ํ™”์‹œ์ผœ์ฃผ๋Š” ์šด๋™์€ ๊ทผ์œก์„ ๊ฐ•ํ™”์‹œ์ผœ์ฃผ๋Š” ํšจ๊ณผ๊ฐ€ ์žˆ๋‹ค.
๊ทผ์œก์„ ๊ฐ•ํ™”์‹œ์ผœ์ฃผ๋Š” ์šด๋™์€ ๊ทผ์œก์„ ๊ฐ•ํ™”

์˜ค... ์ž˜ ์ƒ์„ฑํ•ด๋ƒˆ๊ตฐ์š” ใ…Žใ…Ž. ํ•˜์ง€๋งŒ ์ž์„ธํžˆ ๋ณด๋ฉด ๊ทœ์น™์ ์ธ ์ƒํ™œ์Šต๊ด€์ด ์ค‘์š”ํ•˜๋‹ค๊ณ  ๋‚ด์šฉ์„ ๋ฐ˜๋ณตํ•˜๋Š” ๋ฌธ์ œ๊ฐ€ ๋ณด์ด๋Š”๊ตฐ์š”...!

์ด๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ natural language generation์—์„œ ์ผ๋ฐ˜์ ์ธ ๋ฌธ์ œ์ด๋ฉฐ greedy search, beam search์™€ ๊ฐ™์€ maximization ๊ธฐ๋ฒ•์—์„œ ํ›จ์”ฌ ๋” ์‹ฌํ•˜๊ฒŒ ๋ฐœ์ƒ๋ฉ๋‹ˆ๋‹ค.

Greedy search์˜ ์ฃผ๋œ ๋‹จ์ ์€ ๋‚ฎ์€ ํ™•๋ฅ  ์‚ฌ์ด์— ์ˆจ๊ฒจ์ง„ ๋†’์€ ํ™•๋ฅ ์˜ ๋‹จ์–ด๋ฅผ ๋†“์น˜๋Š” ๊ฒƒ ์ž…๋‹ˆ๋‹ค. ์œ„์˜ ์˜ˆ์‹œ์—์„œ๋„ (The, dog, has)๋ฅผ ๋†“์ณค์ฃ . ์ด ๋ฌธ์žฅ์€ ์‚ฌ์‹ค $0.4 \times 0.9 = 0.36$์œผ๋กœ ์œ„์˜ ๋ฌธ์žฅ๋ณด๋‹ค ์กฐ๊ธˆ ๋” ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ๋Š” ๋ฌธ์žฅ์ž…๋‹ˆ๋‹ค.

๊ณ ๋ง™๊ฒŒ๋„ beam search๊ฐ€ ์œ„ ๋ฌธ์ œ๋ฅผ ์กฐ๊ธˆ ๋œ์–ด์ค๋‹ˆ๋‹ค!

  • alleviate์ž…๋‹ˆ๋‹ค. ํ•ด๊ฒฐํ•ด์ฃผ์ง€๋Š” ์•Š์Šต๋‹ˆ๋‹ค...

Beam search๋Š” ๊ฐ time step์—์„œ ๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์€ ๊ฐ€์„ค์„ num_beams๋งŒํผ ์œ ์ง€ํ•˜๊ณ  ๊ฒฐ๊ตญ ์ „์ฒด ํ™•๋ฅ ์ด ๊ฐ€์žฅ ๋†’์€ ๊ฐ€์„ค(hypothesis)๋ฅผ ์„ ํƒํ•˜์—ฌ ์ˆจ๊ฒจ์ง„ ๋†’์€ ํ™•๋ฅ ์˜ word sequence๋ฅผ ๋†“์น  ์œ„ํ—˜์„ ์ค„์ž…๋‹ˆ๋‹ค.

์•„๋ž˜ ์˜ˆ์‹œ๋Š” num_beams๊ฐ€ 2์ผ ๋•Œ beam search์˜ ๋™์ž‘ ๊ณผ์ •์ž…๋‹ˆ๋‹ค.

[Figure: beam search with num_beams=2 over the same toy tree]

Time step 1์—์„œ ๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์€ ๊ฐ€์„ค์€ (The, nice)๋กœ ํ™•๋ฅ ์ด 0.5, ๊ทธ ๋‹ค์Œ์œผ๋กœ ๋†’์€ ํ™•๋ฅ ์„ ๋ณด์ด๋Š” ๊ฐ€์„ค์€ (The, dog)๋กœ ํ™•๋ฅ ์ด 0.4์ž…๋‹ˆ๋‹ค. greedy search์—์„œ๋Š” top-1๋งŒ ์„ ํƒํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— ์•„๋ž˜ ๊ฐ€์„ค์ด ๋ฌด์‹œ๋˜์—ˆ์ง€๋งŒ beam search์—์„œ๋Š” num_beams๋งŒํผ ๊ฐ€์„ค์„ ์œ ์ง€ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋‘ ๋ฒˆ์งธ๋กœ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” ๊ฐ€์„ค (The, dog)๋ฅผ ๊ธฐ๊ฐํ•˜์ง€ ์•Š๊ณ  ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค. (์–ด๋–ป๊ฒŒ ์œ ์ง€ํ•˜๋Š”์ง€๋Š” Ch2)

Time step 2์—์„œ (The, dog, has)๋Š” 0.36์˜ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” ๊ฐ€์„ค์ด๊ณ  greedy search์˜ ๊ฒฐ๊ณผ์˜€๋˜ (The, nice, woman)์€ 0.2์˜ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” ๊ฐ€์„ค์ž…๋‹ˆ๋‹ค. ์–ด๋•Œ์š” ํ™•๋ฅ ์ด ๋’ค์ง‘ํ˜”์ฃ ?

Beam search๋Š” ํ•ญ์ƒ greedy search๋ณด๋‹ค ๋†’์€ ํ™•๋ฅ ๋กœ output sequence๋ฅผ ์ฐพ์•„์ค๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ๋Š” ์ถœ๋ ฅ์„ ์ฐพ๋Š” ๊ฒƒ์€ ๋ณด์žฅ๋˜์ง€ ์•Š์ฃ . (sub-optimal solution)

transformers์—์„œ์˜ ์˜ˆ์‹œ๋ฅผ ๋ด…์‹œ๋‹ค!

beam_output = model.generate(
    input_ids, 
    max_length=128, 
    num_beams=5, 
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ

์Œ... ํ•˜์ง€๋งŒ ์ถœ๋ ฅ์— ์—ฌ์ „ํžˆ ๋™์ผํ•œ word sequence์˜ ๋ฐ˜๋ณต์ด ํฌํ•จ๋˜๋Š” ๊ตฐ์š”...

์ด์— ๋Œ€ํ•œ ๊ฐ„๋‹จํ•œ ํ•ด๊ฒฐ์ฑ…์€ Paulus ์—ฐ๊ตฌ์ง„์ด ๋„์ž…ํ•œ n-grams penalty(a.k.a word sequences of n words)๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

๊ฐ€์žฅ ํ”ํ•œ n-grams penalty๋Š” ์ด๋ฏธ ๋ณธ n-gram์„ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” ๋‹ค์Œ ๋‹จ์–ด์˜ ํ™•๋ฅ ์„ ์ˆ˜๋™์œผ๋กœ 0์œผ๋กœ ์„ค์ •ํ•˜์—ฌ n-gram์ด ๋‘ ๋ฒˆ ๋‹ค์‹œ ๋‚˜ํƒ€๋‚˜์ง€ ์•Š๋„๋ก ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

no_repeat_ngram_size๋ฅผ 2๋กœ ์„ค์ •ํ•˜์—ฌ ๋™์ผํ•œ n-gram์ด 2๋ฒˆ ์ด์ƒ ๋ฐ˜๋ณต๋˜์ง€ ์•Š๋„๋ก ์ˆ˜์ •ํ•ด๋ณด์ฃ !

beam_output = model.generate(
    input_ids, 
    max_length=128, 
    num_beams=5,
    no_repeat_ngram_size=2,
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ๊ฐ€์žฅ ์ค‘์š”ํ•˜๋‹ค.
์ฝœ๋ผ๊ฒ์€ ํ”ผ๋ถ€์˜ ํƒ„๋ ฅ์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ”ผ๋ถ€ ๋…ธํ™”๋ฅผ ์˜ˆ๋ฐฉํ•˜๊ณ  ํƒ„๋ ฅ ์žˆ๋Š” ํ”ผ๋ถ€๋กœ ๊ฐ€๊ฟ”์ฃผ๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
๋˜ํ•œ ํ”ผ๋ถ€ ํƒ„๋ ฅ์ด ๋–จ์–ด์ง€๊ธฐ ์‰ฌ์šด ๊ฒจ์šธ์ฒ ์—๋Š” ๋ณด์Šต๊ณผ ์˜์–‘์„ ๋™์‹œ์— ์ฑ™๊ธธ ์ˆ˜ ์žˆ๋Š” ์ œํ’ˆ์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๊ฒจ์šธ์ฒ ์—๋Š” ํ”ผ๋ถ€๊ฐ€ ๊ฑด์กฐํ•ด์ง€๊ธฐ ์‰ฌ์šฐ๋ฏ€๋กœ ์ถฉ๋ถ„ํ•œ ์ˆ˜๋ถ„์„ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. ํ˜„๋Œ€์ž๋™์ฐจ(ํšŒ์žฅ ์ •๋ชฝ๊ตฌ)๋Š” ์ง€๋‚œ๋‹ฌ ๊ตญ๋‚ด ์ž๋™์ฐจ ํŒ๋งค๋Ÿ‰์ด ์ „๋…„ ๋™์›” ๋Œ€๋น„ 4.2% ๊ฐ์†Œํ•œ 1๋งŒ2์ฒœ511๋Œ€๋ฅผ ๊ธฐ๋กํ–ˆ๋‹ค๊ณ  1์ผ ๋ฐํ˜”๋‹ค.
ํ˜„๋Œ€์ฐจ๋Š” ์ง€๋‚œํ•ด 12์›” ๊ตญ๋‚ด ์‹œ์žฅ์—์„œ

๋ฐ˜๋ณต์˜ ๋ฌธ์ œ๋Š” ํ•ด๊ฒฐ๋˜์—ˆ๊ตฐ์š”! (๋‹ค๋งŒ ๊ฐ‘๋ถ„ ํ˜„๋Œ€์ž๋™์ฐจ... ์ €๋Š” ๋‹ค๋งŒ ๊ทผ์œก์ด ์ปค์ง€๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ๊ณ  ์‹ถ์—ˆ์Šต๋‹ˆ๋‹ค๋งŒ...)

๋‹ค๋งŒ patricks์— ๋”ฐ๋ฅด๋ฉด n-gram penalty๋Š” ์ฃผ์˜ํ•ด์„œ ์‚ฌ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด New York ์‹œ์— ๋Œ€ํ•ด ์ƒ์„ฑ๋œ ๊ธฐ์‚ฌ๋Š” 2-gram penalty๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ์ „์ฒด ํ…์ŠคํŠธ์—์„œ ๋„์‹œ ์ด๋ฆ„์ด ํ•œ ๋ฒˆ๋งŒ ๋‚˜ํƒ€๋‚˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

Beam search์˜ ๋˜ ๋‹ค๋ฅธ ์ค‘์š”ํ•œ ๊ธฐ๋Šฅ์€ ์ƒ์„ฑ ํ›„ top beams๋ฅผ ๋น„๊ตํ•˜๊ณ  ๋ชฉ์ ์— ๊ฐ€์žฅ ์ž˜ ๋งž๋Š” generated beam์„ ์„ ํƒํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.

๐Ÿค— transformers์—์„œ num_return_sequences ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์„ธํŒ…ํ•˜๋ฉด ์œ„์˜ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • num_return_sequences๋Š” ํ•ญ์ƒ num_beams๋ณด๋‹ค ์ž‘์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.
beam_outputs = model.generate(
    input_ids, 
    max_length=128, 
    num_beams=5,
    no_repeat_ngram_size=2,
    num_return_sequences=5,
    early_stopping=True
)

# now we have 5 output sequences
print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
    decoded_text = tokenizer.decode(beam_output, skip_special_tokens=True)
    print(f"{i}: {decoded_text}", end="\n\n")
Output:
----------------------------------------------------------------------------------------------------
0: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ๊ฐ€์žฅ ์ค‘์š”ํ•˜๋‹ค.
์ฝœ๋ผ๊ฒ์€ ํ”ผ๋ถ€์˜ ํƒ„๋ ฅ์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ”ผ๋ถ€ ๋…ธํ™”๋ฅผ ์˜ˆ๋ฐฉํ•˜๊ณ  ํƒ„๋ ฅ ์žˆ๋Š” ํ”ผ๋ถ€๋กœ ๊ฐ€๊ฟ”์ฃผ๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
๋˜ํ•œ ํ”ผ๋ถ€ ํƒ„๋ ฅ์ด ๋–จ์–ด์ง€๊ธฐ ์‰ฌ์šด ๊ฒจ์šธ์ฒ ์—๋Š” ๋ณด์Šต๊ณผ ์˜์–‘์„ ๋™์‹œ์— ์ฑ™๊ธธ ์ˆ˜ ์žˆ๋Š” ์ œํ’ˆ์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๊ฒจ์šธ์ฒ ์—๋Š” ํ”ผ๋ถ€๊ฐ€ ๊ฑด์กฐํ•ด์ง€๊ธฐ ์‰ฌ์šฐ๋ฏ€๋กœ ์ถฉ๋ถ„ํ•œ ์ˆ˜๋ถ„์„ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. ํ˜„๋Œ€์ž๋™์ฐจ(ํšŒ์žฅ ์ •๋ชฝ๊ตฌ)๋Š” ์ง€๋‚œ๋‹ฌ ๊ตญ๋‚ด ์ž๋™์ฐจ ํŒ๋งค๋Ÿ‰์ด ์ „๋…„ ๋™์›” ๋Œ€๋น„ 4.2% ๊ฐ์†Œํ•œ 1๋งŒ2์ฒœ511๋Œ€๋ฅผ ๊ธฐ๋กํ–ˆ๋‹ค๊ณ  1์ผ ๋ฐํ˜”๋‹ค.
ํ˜„๋Œ€์ฐจ๋Š” ์ง€๋‚œํ•ด 12์›” ๊ตญ๋‚ด ์‹œ์žฅ์—์„œ

1: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ๊ฐ€์žฅ ์ค‘์š”ํ•˜๋‹ค.
์ฝœ๋ผ๊ฒ์€ ํ”ผ๋ถ€์˜ ํƒ„๋ ฅ์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ”ผ๋ถ€ ๋…ธํ™”๋ฅผ ์˜ˆ๋ฐฉํ•˜๊ณ  ํƒ„๋ ฅ ์žˆ๋Š” ํ”ผ๋ถ€๋กœ ๊ฐ€๊ฟ”์ฃผ๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
๋˜ํ•œ ํ”ผ๋ถ€ ํƒ„๋ ฅ์ด ๋–จ์–ด์ง€๊ธฐ ์‰ฌ์šด ๊ฒจ์šธ์ฒ ์—๋Š” ๋ณด์Šต๊ณผ ์˜์–‘์„ ๋™์‹œ์— ์ฑ™๊ธธ ์ˆ˜ ์žˆ๋Š” ์ œํ’ˆ์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๊ฒจ์šธ์ฒ ์—๋Š” ํ”ผ๋ถ€๊ฐ€ ๊ฑด์กฐํ•ด์ง€๊ธฐ ์‰ฌ์šฐ๋ฏ€๋กœ ์ถฉ๋ถ„ํ•œ ์ˆ˜๋ถ„์„ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. ํ˜„๋Œ€์ž๋™์ฐจ(ํšŒ์žฅ ์ •๋ชฝ๊ตฌ)๋Š” ์ง€๋‚œ๋‹ฌ ๊ตญ๋‚ด ์ž๋™์ฐจ ํŒ๋งค๋Ÿ‰์ด ์ „๋…„ ๋™์›” ๋Œ€๋น„ 4.2% ๊ฐ์†Œํ•œ 1๋งŒ2์ฒœ567๋Œ€๋ฅผ ๊ธฐ๋กํ–ˆ๋‹ค๊ณ  1์ผ ๋ฐํ˜”๋‹ค.
ํ˜„๋Œ€์ฐจ๋Š” ์ง€๋‚œํ•ด 12์›” ๊ตญ๋‚ด ์‹œ์žฅ์—์„œ

2: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ๊ฐ€์žฅ ์ค‘์š”ํ•˜๋‹ค.
์ฝœ๋ผ๊ฒ์€ ํ”ผ๋ถ€์˜ ํƒ„๋ ฅ์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ”ผ๋ถ€ ๋…ธํ™”๋ฅผ ์˜ˆ๋ฐฉํ•˜๊ณ  ํƒ„๋ ฅ ์žˆ๋Š” ํ”ผ๋ถ€๋กœ ๊ฐ€๊ฟ”์ฃผ๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
๋˜ํ•œ ํ”ผ๋ถ€ ํƒ„๋ ฅ์ด ๋–จ์–ด์ง€๊ธฐ ์‰ฌ์šด ๊ฒจ์šธ์ฒ ์—๋Š” ๋ณด์Šต๊ณผ ์˜์–‘์„ ๋™์‹œ์— ์ฑ™๊ธธ ์ˆ˜ ์žˆ๋Š” ์ œํ’ˆ์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๊ฒจ์šธ์ฒ ์—๋Š” ํ”ผ๋ถ€๊ฐ€ ๊ฑด์กฐํ•ด์ง€๊ธฐ ์‰ฌ์šฐ๋ฏ€๋กœ ์ถฉ๋ถ„ํ•œ ์ˆ˜๋ถ„์„ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. ํ˜„๋Œ€์ž๋™์ฐจ(ํšŒ์žฅ ์ •๋ชฝ๊ตฌ)๋Š” ์ง€๋‚œ๋‹ฌ ๊ตญ๋‚ด ์ž๋™์ฐจ ํŒ๋งค๋Ÿ‰์ด ์ „๋…„ ๋™์›” ๋Œ€๋น„ 4.2% ๊ฐ์†Œํ•œ 1๋งŒ2์ฒœ567๋Œ€๋ฅผ ๊ธฐ๋กํ–ˆ๋‹ค๊ณ  1์ผ ๋ฐํ˜”๋‹ค.
ํ˜„๋Œ€์ฐจ๋Š” ์ง€๋‚œํ•ด 12์›” ๋‚ด์ˆ˜ ํŒ๋งค

3: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ๊ฐ€์žฅ ์ค‘์š”ํ•˜๋‹ค.
์ฝœ๋ผ๊ฒ์€ ํ”ผ๋ถ€์˜ ํƒ„๋ ฅ์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ”ผ๋ถ€ ๋…ธํ™”๋ฅผ ์˜ˆ๋ฐฉํ•˜๊ณ  ํƒ„๋ ฅ ์žˆ๋Š” ํ”ผ๋ถ€๋กœ ๊ฐ€๊ฟ”์ฃผ๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
๋˜ํ•œ ํ”ผ๋ถ€ ํƒ„๋ ฅ์ด ๋–จ์–ด์ง€๊ธฐ ์‰ฌ์šด ๊ฒจ์šธ์ฒ ์—๋Š” ๋ณด์Šต๊ณผ ์˜์–‘์„ ๋™์‹œ์— ์ฑ™๊ธธ ์ˆ˜ ์žˆ๋Š” ์ œํ’ˆ์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๊ฒจ์šธ์ฒ ์—๋Š” ํ”ผ๋ถ€๊ฐ€ ๊ฑด์กฐํ•ด์ง€๊ธฐ ์‰ฌ์šฐ๋ฏ€๋กœ ์ถฉ๋ถ„ํ•œ ์ˆ˜๋ถ„์„ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. ํ˜„๋Œ€์ž๋™์ฐจ(ํšŒ์žฅ ์ •๋ชฝ๊ตฌ)๋Š” ์ง€๋‚œ๋‹ฌ ๊ตญ๋‚ด ์ž๋™์ฐจ ํŒ๋งค๋Ÿ‰์ด ์ „๋…„ ๋™์›” ๋Œ€๋น„ 4.2% ๊ฐ์†Œํ•œ 1๋งŒ2์ฒœ511๋Œ€๋ฅผ ๊ธฐ๋กํ–ˆ๋‹ค๊ณ  1์ผ ๋ฐํ˜”๋‹ค.
ํ˜„๋Œ€์ฐจ๋Š” ์ง€๋‚œํ•ด 12์›” ๊ตญ๋‚ด ํŒ๋งค

4: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ”ผ๋ถ€ ์† ์ฝœ๋ผ๊ฒ๊ณผ ์—˜๋ผ์Šคํ‹ด์˜ ์ƒ์„ฑ์„ ์ด‰์ง„์‹œํ‚ค๋Š” ๊ฒƒ์ด ๊ฐ€์žฅ ์ค‘์š”ํ•˜๋‹ค.
์ฝœ๋ผ๊ฒ์€ ํ”ผ๋ถ€์˜ ํƒ„๋ ฅ์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ”ผ๋ถ€ ๋…ธํ™”๋ฅผ ์˜ˆ๋ฐฉํ•˜๊ณ  ํƒ„๋ ฅ ์žˆ๋Š” ํ”ผ๋ถ€๋กœ ๊ฐ€๊ฟ”์ฃผ๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
๋˜ํ•œ ํ”ผ๋ถ€ ํƒ„๋ ฅ์ด ๋–จ์–ด์ง€๊ธฐ ์‰ฌ์šด ๊ฒจ์šธ์ฒ ์—๋Š” ๋ณด์Šต๊ณผ ์˜์–‘์„ ๋™์‹œ์— ์ฑ™๊ธธ ์ˆ˜ ์žˆ๋Š” ์ œํ’ˆ์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๊ฒจ์šธ์ฒ ์—๋Š” ํ”ผ๋ถ€๊ฐ€ ๊ฑด์กฐํ•ด์ง€๊ธฐ ์‰ฌ์šฐ๋ฏ€๋กœ ์ถฉ๋ถ„ํ•œ ์ˆ˜๋ถ„์„ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. ํ˜„๋Œ€์ž๋™์ฐจ(ํšŒ์žฅ ์ •๋ชฝ๊ตฌ)๋Š” ์ง€๋‚œ๋‹ฌ ๊ตญ๋‚ด ์ž๋™์ฐจ ํŒ๋งค๋Ÿ‰์ด ์ „๋…„ ๋™์›” ๋Œ€๋น„ 4.2% ๊ฐ์†Œํ•œ 1๋งŒ2์ฒœ511๋Œ€๋ฅผ ๊ธฐ๋กํ–ˆ๋‹ค๊ณ  1์ผ ๋ฐํ˜”๋‹ค.
ํ˜„๋Œ€์ฐจ๋Š” ์ง€๋‚œํ•ด 12์›” ๋‚ด์ˆ˜ ํŒ๋งค

์Œ... ๋ฐ˜ํ™˜๋ฐ›์•˜์ง€๋งŒ ๊ฐ beam๋“ค์ด ํฌ๊ฒŒ ๋‹ค๋ฅด์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

Open-ended generation์—์„œ ์ตœ๊ทผ์— beam search๊ฐ€ ์ตœ์„ ์ด ์•„๋‹ ์ˆ˜ ์žˆ๋‹ค๋Š” ๋ช‡ ๊ฐ€์ง€ ์ด์œ ๊ฐ€ ์ œ๊ธฐ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

  • Beam search๋Š” ๊ธฐ๊ณ„ ๋ฒˆ์—ญ์ด๋‚˜ ์š”์•ฝ๊ฐ™์ด ์›ํ•˜๋Š” ์ƒ์„ฑ์˜ ๊ธธ์ด๊ฐ€ ์–ด๋Š ์ •๋„ ์˜ˆ์ธก ๊ฐ€๋Šฅํ•œ ์ž‘์—…์—์„œ ๋งค์šฐ ์ž˜ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์›ํ•˜๋Š” ์ถœ๋ ฅ์˜ ๊ธธ์ด๊ฐ€ ํฌ๊ฒŒ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ๋Š” open-ended generation์˜ ๊ฒฝ์šฐ(๋Œ€ํ™” ํ˜น์€ story ์ƒ์„ฑ) ๊ทธ๋ ‡์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
  • ์œ„์—์„œ ํ™•์ธํ–ˆ๋“ฏ beam search๋Š” repetitive generation์— ์ทจ์•ฝํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” n-gram ํ˜น์€ ๋‹ค๋ฅธ penalty๋กœ ์ ์ ˆํžˆ ์กฐ์ •ํ•˜๊ธฐ๊ฐ€ ์–ด๋ ต์Šต๋‹ˆ๋‹ค.
  • Holtzman์— ๋”ฐ๋ฅด๋ฉด, High quality human language๋Š” ๋‹ค์Œ ๋‹จ์–ด๊ฐ€ ๊ฐ€์žฅ ๋†’๊ฒŒ ์˜ฌ ํ™•๋ฅ  ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅด์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์ฆ‰ ์ธ๊ฐ„์€ ์ƒ์„ฑ๋œ ํ…์ŠคํŠธ๊ฐ€ ์šฐ๋ฆฌ๋ฅผ ๋†€๋ผ๊ฒŒ ํ•˜๊ณ (surprise) ์ง€๋ฃจํ•˜๊ฑฐ๋‚˜(boring) ์˜ˆ์ธกํ•  ์ˆ˜ ์—†๊ธฐ๋ฅผ(not to be predictable) ์›ํ•ฉ๋‹ˆ๋‹ค. ์ €์ž๋Š” BeamSearch๋กœ ์ƒ์„ฑ๋œ text๊ฐ€ ๋œ ๋†€๋ž๋‹ค๋Š” ๊ฒƒ์„ ์•„๋ž˜ plot์œผ๋กœ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.

[Figure: per-token probability of beam-search text vs. human text (Holtzman et al.)]

์ž, ์ง€๋ฃจํ•œ text๋Š” ๊ทธ๋งŒ ์ƒ์„ฑํ•˜๊ณ  randomness๋ฅผ ๋„์ž…ํ•ฉ์‹œ๋‹ค :)

Sampling

Temperature-Sampling

๊ฐ€์žฅ ๊ธฐ๋ณธ์ ์ธ ํ˜•ํƒœ์˜ sampling์€ ๋‹ค์Œ ๋‹จ์–ด $w_t$๋ฅผ conditional probability distribution์— ๋”ฐ๋ผ ์ž„์˜๋กœ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

$$w_t\sim P(w|w_{1:t-1})$$

์•„๋ž˜ ์‹œ๊ฐํ™”๋กœ sampling ์‹œ ์–ธ์–ด ์ƒ์„ฑ์— ๋Œ€ํ•ด ์•Œ์•„๋ด…์‹œ๋‹ค.

[Figure: sampling the next word from the full conditional distribution]

์ƒ˜ํ”Œ๋ง์„ ์‚ฌ์šฉํ•œ ์–ธ์–ด ์ƒ์„ฑ์€ ๋” ์ด์ƒ ๊ฒฐ์ •์ ์ด์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

  • Maximization ๊ธฐ๋ฒ•(greedy, beam)์€ ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” ๊ฐ€์„ค๋งŒ์„ ํƒํ–ˆ์Šต๋‹ˆ๋‹ค.

๋‹จ์–ด car์€ beam์„ 3๋งŒํผ ๋Š˜๋ฆฌ์ง€ ์•Š์œผ๋ฉด ์ด์ „ maximization ๊ธฐ๋ฒ•์—์„œ๋Š” ์–ด๋– ํ•œ ๊ฒฝ์šฐ์—๋„ ์ ˆ๋Œ€๋กœ ์ฑ„ํƒ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ •๋ง ๋‚ฎ์€ ํ™•๋ฅ ๋กœ ์กฐ๊ฑด๋ถ€ ํ™•๋ฅ  ๋ถ„ํฌ $P(w|`the`)$์—์„œ ๋‹จ์–ด car๊ฐ€ ์ถ”์ถœ๋  ์ˆ˜ ์žˆ์œผ๋ฉฐ ๋‹ค์Œ ๋‹จ์–ด์ธ drives, is, turns๊ฐ€ ์กฐ๊ฑด๋ถ€ ํ™•๋ฅ  ๋ถ„ํฌ $P(w|`the`,`car)$์—์„œ ์ถ”์ถœ๋  ๊ฒƒ ์ž…๋‹ˆ๋‹ค.

๐Ÿค— transformers์—์„œ do_sample ์˜ต์…˜์„ ํ™œ์„ฑํ™”์‹œํ‚ค๊ณ  top-k sampling์„ ๋น„ํ™œ์„ฑํ™”์‹œ์ผœ์„œ ์œ„๋ฅผ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

import random
import torch
import numpy as np

def set_seed(seed: int = 42):
    """Seed fixer (random, numpy, torch)
    Args:
        seed (:obj:`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
set_seed()
sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=128,
    top_k=0,
)


print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ์ผ๋‹จ ๊ต์ •์— ํž˜์จ์•ผ ํ•œ๋‹ค.
์—ฌ๊ธฐ์„œ ์ค‘์š”ํ•œ ๊ฒƒ์€ ๊ต์ •์„ ์ž˜๋ชปํ•˜์˜€์„ ๋•Œ ๊ทธ ๊ต์ •์ด ์ž˜๋ชป ๋˜์—ˆ๋Š”์ง€ ํ•œ๋ฒˆ ๋งํ•ด์ฃผ๋ฉด ์Šค์Šค๋กœ๊ฐ€ ๊ต์ •๋˜๋Š” ๊ฒƒ์ด๋‹ค.
๊ต์ •์ˆ˜์ˆ ์˜ ๊ฐ€์žฅ ํฐ ์›์น™์€ ๊ต์ •์ด ์–ด๋ ต๋‹ค๊ณ  ํ•˜์—ฌ ํ•˜๋Š” ๊ฒƒ์ด๋‹ค.
์‹ฌํ•œ ๊ฒฝ์šฐ์—๋Š” ๊ต์ •์—๋งŒ ์ „๋…ํ•˜๋Š” ๊ฒƒ์ด ์ตœ์„ ์˜ ๋ฐฉ๋ฒ•์ด๋‹ค.
ํŠนํžˆ ๊ต์ •์— ์‹œ๊ฐ„์„ ํˆฌ์žํ•˜๋‹ค
๋ณด๋ฉด ๊ต์ •์ˆ˜์ˆ ์— ์š•์‹ฌ์ด ์ƒ๊ธฐ๊ฒŒ ๋˜๋Š” ์ฆ‰์ด ์žˆ๋‹ค.
๋ฌผ๋ก  ์ด ๋ฐฉ๋ฒ•์€ ๊ต์ •์ด ๊ธฐ๊ณ„์ ์ž„์„ ์ด์šฉํ•˜์—ฌ ์ง€์†์ ์ธ ๊ต์ •์ˆ˜์ˆ ์„ ํ•˜๋Š” ๋ฐฉ๋ฒ•์ด์ง€๋งŒ, ๊ต์ •์„ ์ž˜ ํ•˜๋ฉด ์˜คํžˆ๋ ค ๊ธฐ์กด์˜ ๊ต์ •์ˆ˜์ˆ ์— ๋น„ํ•ด ๊ต์ •๋ ฅ์ด ๋” ๊ฐ•ํ•ด์งˆ ์ˆ˜ ์žˆ๋‹ค.
๋˜ํ•œ ์ตœ์„ ์˜ ๊ต์ •์น˜๋ฃŒ๋ฅผ ํ•˜๋ฉด ๊ฐœ์„ ๋œ๋‹ค.
์ฆ‰ ๊ต์ •์€ ๊ต์ •์ด ์ง„ํ–‰๋ ์ˆ˜๋ก ๊ฒฐ๊ตญ์€ ๊ฐœ์„ ๋œ๋‹ค.
๊ต

์–ด... ๋‚ด์šฉ์€ ์ฐจ์น˜ํ•˜๊ณ  ๋ฐ˜๋ณต ๋ฌธ์ œ๋Š” ์•ˆ๋ณด์ด๋„ค์š”! ํ•˜์ง€๋งŒ ํ‘œํ˜„์ด ์ด์ƒํ•ฉ๋‹ˆ๋‹ค.

  • ๊ต์ •์ด ์–ด๋ ต๋‹ค๊ณ  ํ•˜์—ฌ ํ•˜๋Š” ๊ฒƒ์ด๋‹ค.
  • ํŠนํžˆ ๊ต์ •์— ์‹œ๊ฐ„์„ ํˆฌ์žํ•˜๋‹ค
  • ๋ณด๋ฉด ๊ต์ •์ˆ˜์ˆ ์— ์š•์‹ฌ์ด ์ƒ์‹œ๊ฒŒ ๋˜๋Š” ์ฆ‰์ด ์žˆ๋‹ค

์ด๋Š” word sequence๋ฅผ samplingํ•  ๋•Œ ์ƒ๊ธฐ๋Š” ํฐ ๋ฌธ์ œ์ž…๋‹ˆ๋‹ค. ๋ชจ๋ธ์€ ์ข…์ข… ์ผ๊ด€์„ฑ์—†์ด ํšก์„ค์ˆ˜์„คํ•ฉ๋‹ˆ๋‹ค.

์œ„๋ฅผ ํ•ด๊ฒฐํ•  ํŠธ๋ฆญ์€ ๋ถ„ํฌ $P(w|w_{1:t-1})$์„ sharpํ•˜๊ฒŒ ๋งŒ๋“œ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

  • ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” ๋‹จ์–ด์˜ likelihood๋ฅผ ๋†’์ด๊ณ 
  • ๊ฐ€์žฅ ๋‚ฎ์€ ํ™•๋ฅ ์„ ๊ฐ€์ง€๋Š” ๋‹จ์–ด์˜ likelihood๋ฅผ ๋‚ฎ์ถ”๋Š” ๊ฒƒ

์œ„ ํŠธ๋ฆญ์€ softmax์˜ temperature๋ผ๊ณ  ๋ถˆ๋ฆฝ๋‹ˆ๋‹ค.

temperature๋ฅผ ์ ์šฉํ•œ ์˜ˆ์‹œ์— ๋Œ€ํ•œ ์‹œ๊ฐํ™”์ž…๋‹ˆ๋‹ค.

[Figure: the toy "The" distribution sharpened by temperature]

temperature๋ฅผ ์ ์šฉํ•˜๊ธฐ ์ „์—๋Š” $P(w|`the`)$์—์„œ car๊ฐ€ ๋ฝ‘ํž ํ™•๋ฅ ์ด 0.1์ด์—ˆ์ง€๋งŒ ์ง€๊ธˆ์€ 0.02์ž…๋‹ˆ๋‹ค. ๋‚ฎ์•„์ง„ ๋งŒํผ ๋ฝ‘ํžˆ๊ธฐ๋Š” ๋” ํž˜๋“ค๊ฒ ์ฃ ?

sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=128,
    top_k=0,
    temperature=0.7,
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ์นผ์Š˜๊ณผ ๋งˆ๊ทธ๋„ค์Š˜์ด ํ’๋ถ€ํ•œ ์Œ์‹์„ ๊พธ์ค€ํžˆ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๋˜ํ•œ ๊ทœ์น™์ ์ธ ์šด๋™, ์ŠคํŠธ๋ ˆ์Šค ์œ ๋ฐœ๊ณผ ๊ฐ™์€ ์ƒํ™œ์Šต๊ด€ ๊ด€๋ฆฌ์—๋„ ์‹ ๊ฒฝ์„ ์จ์•ผ ํ•œ๋‹ค.
ํ•œ์–‘๋Œ€๋ณ‘์› ๊ฐ€์ •์˜ํ•™๊ณผ ์–‘ํ˜•์ฒ  ๊ต์ˆ˜๋Š” โ€œ์ง€๋ฐฉ๊ฐ„ ์งˆํ™˜์„ ๊ฐœ์„ ํ•˜๊ณ  ์‚ถ์˜ ์งˆ์„ ๋†’์ด๋Š” ๋ฐ ๋„์›€์ด ๋˜๋Š” ๋น„ํƒ€๋ฏผD์™€ ์นผ์Š˜์„ ๋งŽ์ด ์„ญ์ทจํ•˜๋ฉด ๊ฐ„ ๊ฑด๊ฐ•์— ๋„์›€์ด ๋  ์ˆ˜ ์žˆ๋‹คโ€๋ฉฐ ๋น„ํƒ€๋ฏผD, ์นผ์Š˜, ๋งˆ๊ทธ๋„ค์Š˜์ด ํ’๋ถ€ํ•œ ์Œ์‹์„ ๊พธ์ค€ํžˆ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ๊ฐ„ ๊ฑด๊ฐ•์— ๋„์›€์ด ๋œ๋‹ค๊ณ  ์กฐ์–ธํ–ˆ๋‹ค. ์„œ์šธ์‹œ๋Š” ์˜ค๋Š” 10์ผ ์˜คํ›„ 5์‹œ30๋ถ„ ๋งˆํฌ๊ตฌ ์„œ๊ต๋™ ํ™์ต๋Œ€ ์ธ๊ทผ ์„ ์œ ๋„๊ณต์›์—์„œ 'ํ”Œ๋ผ์›Œ ํŽ˜์Šคํ‹ฐ๋ฒŒ-์„ ์œ ๋„๊ณต์›์„ ์ฐพ์•„๋ผ' ํ–‰์‚ฌ๋ฅผ ๊ฐœ์ตœํ•œ๋‹ค๊ณ  9์ผ

Maximization์˜ ๊ฒฐ๊ณผ์™€ ์œ ์‚ฌํ•˜๋ฉด์„œ ๋ฐ˜๋ณต์€ ์•ˆํ•˜๊ณ  ๋‹ค๋ฅธ ๋‚ด์šฉ๊นŒ์ง€ ์ถ”๊ฐ€๋˜์—ˆ์Šต๋‹ˆ๋‹ค! (๋ฌผ๋ก  ๊ทผ์œก๊ณผ๋Š” ์•„์ง๋„ ๊ด€๋ จ์ด ์ ์Šต๋‹ˆ๋‹ค... ๊ทธ๋ž˜๋„ ์ผ๊ด€์„ฑ์€ ๊ฐœ์„ ๋˜์—ˆ๊ตฐ์š”.)

temperature๋ฅผ ์ ์šฉํ•˜๋ฉด ๋ถ„ํฌ๋ฅผ ๋œ randomํ•˜๊ฒŒ ๋งŒ๋“ค ์ˆ˜ ์žˆ์ง€๋งŒ 0์œผ๋กœ ์„ค์ •ํ•˜๋ฉด greedy decoding๊ณผ ๋™์ผํ•ด์ง‘๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋ฉด ์ด์ „๊ณผ ๊ฐ™์€ ๋ฌธ์ œ๋ฅผ ๋‹ค์‹œ ๊ฒช๊ฒŒ ๋˜๊ฒ ์ง€์š”.

Top-K Sampling

Fan ์—ฐ๊ตฌ์ง„์€ ์•„์ฃผ ๊ฐ„๋‹จํ•˜์ง€๋งŒ ๊ฐ•๋ ฅํ•œ sampling scheme์ธ Top-K๋ฅผ ์†Œ๊ฐœํ–ˆ์Šต๋‹ˆ๋‹ค.

Top-K sampling์—์„œ ๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์€ K๊ฐœ์˜ ๋‹ค์Œ ๋‹จ์–ด๋Š” filtering๋˜๊ณ  probability mass๋Š” K๊ฐœ์˜ ๋‹ค์Œ ๋‹จ์–ด์— ๋Œ€ํ•ด์„œ๋งŒ ์žฌ๋ถ„๋ฐฐ๋ฉ๋‹ˆ๋‹ค. GPT2๊ฐ€ ์ด sampling scheme๋ฅผ ํƒํ–ˆ๊ณ  story generation์—์„œ ์„ฑ๊ณต์ ์ด์—ˆ๋˜ ์›์ธ ์ค‘ ํ•˜๋‚˜๋กœ ํ‰๊ฐ€๋ฉ๋‹ˆ๋‹ค.

  • ์ „์ฒด ๋‹จ์–ด ๋ถ„ํฌ์—์„œ ์ƒ˜ํ”Œ๋งํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ ๊ณ ์ •๋œ ์ƒ์œ„ K๊ฐœ์˜ ๋‹จ์–ด์—์„œ sampling ์ˆ˜ํ–‰

[Figure: Top-K sampling with K=6 at two time steps]

Time step 1์—์„œ Top-6๊ฐœ๋ฅผ ์ œ์™ธํ•œ ๋‚˜๋จธ์ง€ people, big, house, cat์€ ์ƒ์„ฑ ๋Œ€์ƒ์—์„œ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. (Top-K๊ฐœ๋งŒ filtering, Vocab์— pick-up)

Step 1์—์„œ๋Š” ์ „์ฒด์˜ 2/3, step 2์—์„œ๋Š” ๊ฑฐ์˜ ๋ชจ๋“  probability mass๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.

Time step 2์—์„œ Top-6๊ฐœ๋ฅผ ์ œ์™ธํ•œ ๋‚˜๋จธ์ง€ not, the, small, told ์ด์ƒํ•œ ๋‹จ์–ด๋“ค์„ ์„ฑ๊ณต์ ์œผ๋กœ ์ œ์™ธํ•˜๊ณ  ์ถ”์ถœํ•˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=128,
    top_k=50,
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ์„ฑ์žฅํ˜ธ๋ฅด๋ชฌ์ด ๋งŽ์ด ๋ถ„๋น„๋ผ์•ผ ํ•˜๋Š”๋ฐ, ๋ฐ˜๋Œ€๋กœ ์„ฑ์žฅ์„ธํฌ ์ฃผ์‚ฌ๋ฅผ ๋งž๊ณ  ์ž๋ผ๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋ถ€์กฑํ•œ ์˜์–‘์†Œ๊ฐ€ ํ•„์š”ํ•˜๊ฒŒ ๋œ๋‹ค.
์„ฑ์žฅํ˜ธ๋ฅด๋ชฌ์ด ๋ถ€์กฑํ•˜๋ฉด ํ˜ˆ์•ก์˜ ์›ํ™œํ•œ ํ๋ฆ„์„ ๋•๋Š” ๋น„ํƒ€๋ฏผA๊ฐ€ ํ’๋ถ€ํ•ด์ง€์ง€๋งŒ, ์„ฑ์žฅ์„ธํฌ๊ฐ€ ๋ถ€์กฑํ•˜๋ฉด ํ˜ˆ์•ก์˜ ํ๋ฆ„์ด ์›ํ™œํ•˜์ง€ ์•Š๊ฒŒ ๋ผ ์˜์–‘๊ณต๊ธ‰์ด ์ œ๋Œ€๋กœ ์ด๋ค„์ง€์งˆ ๋ชปํ•œ๋‹ค๋Š” ๊ฒƒ์€ ๋งค์šฐ ์น˜๋ช…์ ์ด๋‹ค.
์„ฑ์ธ์ด๋ผ๋ฉด ๋ณดํ†ต 10~15% ์ •๋„ ์„ฑ์žฅ์ด ์ž˜ ๋˜์ง€๋งŒ, 20, 30๋Œ€ ์ค‘๋…„ ๋‚จ์„ฑ๋“ค๊ณผ ๊ณ ์—ฐ๋ น์ธต์˜ ๊ฒฝ์šฐ, 40, 50๋Œ€ ์ค‘๋…„ ์—ฌ์„ฑ๋“ค๋ณด๋‹ค ๋” ํฐ ์„ฑ์žฅ์ด ํ•„์š”ํ•˜๋‹ค.
์ด๋ฟ ์•„๋‹ˆ๋ผ ์„ฑํ˜ธ๋ฅด๋ชฌ ์ˆ˜์น˜๋Š” ๋” ๋นจ๋ฆฌ ๊ฐ์†Œํ•˜๋Š” ๊ฒฝํ–ฅ์ด ์žˆ๋‹ค๋Š” ์—ฐ๊ตฌ๊ฒฐ๊ณผ๋„ ์žˆ๋‹ค.
์—ฐ๊ตฌํŒ€๋“ค์€ ํ˜ธ๋ฅด๋ชฌ์ด ๋ถ€์กฑํ•˜๋ฉด ํ˜ˆ์•ก ๊ณต๊ธ‰์ด ๋ถ€์กฑํ•˜๊ณ , ํ˜ˆ์•ก ๋‚ด

์ œ์ผ ๊ดœ์ฐฎ์€ ๊ฒฐ๊ณผ์ธ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค! ๊ฐ€์žฅ ์ธ๊ฐ„๊ฐ™์ด ์ƒ์„ฑ๋œ ๊ฒƒ ๊ฐ™๊ตฐ์š”. Top-K sampling์˜ ํ•œ ๊ฐ€์ง€ ๋ฌธ์ œ๋Š” next word distribution $P(w|w_{1:t-1})$์—์„œ filtering๋˜๋Š” ๋‹จ์–ด์˜ ์ˆ˜๋ฅผ dynamicํ•˜๊ฒŒ ์ ์šฉํ•˜์ง€ ์•Š๋Š” ๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.

  • ๊ณ ์ •๋œ K๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ์— ๋ฌธ์ œ

์œ„ ๊ทธ๋ž˜ํ”„์—์„œ ์˜ค๋ฅธ์ชฝ์˜ ๊ฒฝ์šฐ ๋งค์šฐ sharpํ•œ ๋ถ„ํฌ์—์„œ sampling๋˜์ง€๋งŒ ์™ผ์ชฝ์˜ ๊ฒฝ์šฐ์—๋Š” ๋” flatํ•œ ๋ถ„ํฌ์—์„œ sampling๋˜๊ธฐ์— ๋ฌธ์ œ๊ฐ€ ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Step 1์—์„œ Top-K๋Š” people, big, house, cat ์™€ ๊ฐ™์€ ๊ฐ€๋Šฅ์„ฑ์žˆ๋Š” ํ›„๋ณด๊ตฐ๋“ค์„ ์ œ์™ธํ–ˆ์Šต๋‹ˆ๋‹ค. ๋ฐ˜๋Œ€๋กœ Step 2์—์„œ๋Š” ๋‹จ์–ด์˜ sample pool(In top-k)์— ๋ถ€์ ํ•ฉํ•œ ๋‹จ์–ด down, a๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค. ๋•Œ๋ฌธ์— sample pool์€ ๊ณ ์ •๋œ ํฌ๊ธฐ K๋กœ ์ œํ•œํ•˜๋Š” ๊ฒƒ์€ ๋ชจ๋ธ์ด sharp distribution์— ๋Œ€ํ•ด ํšก์„ค์ˆ˜์„ค(gibberish)ํ•  ์œ„ํ—˜์ด ์žˆ๊ณ  flat distribution์— ๋Œ€ํ•ด ์ฐฝ์˜์„ฑ(creativity)์ด ์ œํ•œ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์œ„ ์ง๊ด€์ด Ari Holtzman ์—ฐ๊ตฌ์ง„๋“ค์ด ์ œ์•ˆํ•œ Top-p ํ˜น์€ nucleus sampling์œผ๋กœ ์ด์–ด์ง‘๋‹ˆ๋‹ค.

Top-p (nucleus) Sampling

๊ฐ€์žฅ ๋†’์€ K๊ฐœ์˜ ๋‹จ์–ด๋ฅผ ์„ ํƒํ•˜๋Š” ๋Œ€์‹  Top-P sampling์€ ๋ˆ„์  ํ™•๋ฅ ์ด ํ™•๋ฅ  p๋ฅผ ์ดˆ๊ณผํ•˜๋Š” ๊ฐ€๋Šฅํ•œ ๊ฐ€์žฅ ์ž‘์€ ๋‹จ์–ด ์ง‘ํ•ฉ์—์„œ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ probability mass๋Š” ์ด ๋‹จ์–ด set ์‚ฌ์ด์— ์žฌ๋ถ„๋ฐฐ๋ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฐ ์‹์œผ๋กœ ๋‹จ์–ด ์ง‘ํ•ฉ์˜ ํฌ๊ธฐ(a.k.a the number of words in the set)์€ ๋‹ค์Œ ๋‹จ์–ด์˜ ํ™•๋ฅ  ๋ถ„ํฌ์— ๋”ฐ๋ผ ๋™์ ์œผ๋กœ ์ฆ๊ฐ€ํ•˜๊ฑฐ๋‚˜ ๊ฐ์†Œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์œ„๋ฅผ ์‹œ๊ฐํ™”ํ•ด๋ด…์‹œ๋‹ค!

[Figure: Top-p sampling with p=0.92 at two time steps]

$p=0.92$๋กœ ์„ค์ •ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. Top-p sampling์€ probability mass์˜ 92%๋ฅผ ์ดˆ๊ณผํ•˜๋Š” ๋‹จ์–ด์˜ minimum number๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผํ…Œ๋ฉด ์œ„์—์„œ cat์„ ์ œ์™ธํ•œ ๋‹จ์–ด์˜ prob mass์˜ ํ•ฉ์€ 0.94๋กœ ์„ค์ •ํ•œ p๋ณด๋‹ค ์ปค์ง€๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์ฆ‰ time step 1์—์„œ๋Š” 9๊ฐœ์˜ ๋‹จ์–ด๋ฅผ ๊ณ ๋ฅด๊ณ  time step 2์—์„œ๋Š” drives, is, turns ๋งŒ์œผ๋กœ๋„ 97%์ž…๋‹ˆ๋‹ค. ๋•Œ๋ฌธ์— 3๊ฐœ์˜ ๋‹จ์–ด๋กœ ๊ณ ์ • ํ›„ sampling์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. Top-K์—์„œ ๊ณ ์ •์ ์ธ K๋กœ samplingํ•œ ๊ฒƒ๊ณผ ๋‹ค๋ฅด๊ฒŒ Top-p์—์„œ๋Š” next word distribution์— ๋”ฐ๋ผ dynamicํ•˜๊ฒŒ sampling pool์„ ๊ฒฐ์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=128,
    top_p=0.92,
    top_k=0,
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋ฌผ ์†์— ์˜ค๋ž˜ ์žˆ๋‹ค๊ฐ€ ๋งฅ์ฃผ๋ฅผ ์‚ผํ‚ค๋Š” ๊ฒƒ์ด ๋„์›€์ด ๋œ๋‹ค.
๋ชธ์ด ์ฐจ๊ฐ€์›Œ์ง€๋ฉด ์•„๋“œ๋ ˆ๋‚ ๋ฆฐ์ด ๋ถ„๋น„๋˜์–ด ๋†๋„๊ฐ€ ๋†’์•„์ง€๋ฏ€๋กœ, ๋ชธ์— ์Œ“์ธ ์•„๋“œ๋ ˆ๋‚ ๋ฆฐ ๋ถ„๋น„๋ฅผ ์กฐ์ ˆํ•ด์•ผ ํ•œ๋‹ค.
๋”ฐ๋ผ์„œ 8000mg์—์„œ 90mg ์ •๋„ ๋จน๋Š” ๊ฒƒ์ด ๋ฐ”๋žŒ์งํ•˜๋‹ค.
๋ฐ• ๊ต์ˆ˜๋Š” โ€œ์ตœ๊ทผ ์œ ํ–‰ํ•˜๋Š” ํ”„๋ฆฌ๋ฐ”์ด์˜คํ‹ฑ์Šค๋Š” ์žฅ๋‚ด ๋ฏธ์ƒ๋ฌผ์˜ ์ฆ์‹ ๋“ฑ ๊ณ ์œ ์˜ ํŠน์„ฑ์„ ๊ฐ–๊ณ  ์žˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์žฅ์—๋„ ๋งŽ์€ ์ข…๋ฅ˜์˜ ์œ ์‚ฐ๊ท ์ด ์ฆ์‹ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋‹ค์–‘ํ•˜๊ธด ํ•˜์ง€๋งŒ ์žฅ์— ์œ ์ตํ•œ ๊ท ์ด ๋งŽ์ด ์ƒ์„ฑ๋˜์ง€ ์•Š๋Š” ํŽธโ€์ด๋ผ๊ณ  ๋ฐํ˜”๋‹ค.
์ •๋‹ต์€ โ€˜๊ณต๋ถ€ ๋น„๋ฒ•โ€™์ด๋‹ค.
๊ณต๋ถ€์— ๋น ์ง„ ์•„์ด๋“ค์˜ ์„ธํฌ๋ง‰์„ ํ˜„๋ฏธ๊ฒฝ์œผ๋กœ ๋“ค์—ฌ๋‹ค๋ณธ๋‹ค.
์•„๊ธฐ์ฒ˜๋Ÿผ ํฌ๊ณ  ๊ฑด๊ฐ•ํ•œ ๋ฒ ๊ฐœ๋ฅผ ๋งŒ๋“ 

์ข‹์Šต๋‹ˆ๋‹ค! ๋งจ ์ฒ˜์Œ samplingํ–ˆ์„ ๊ฒฐ๊ณผ๋ณด๋‹ค๋Š” ํ›จ์”ฌ ๋” ์‚ฌ๋žŒ๋‹ค์›Œ ์กŒ์Šต๋‹ˆ๋‹ค. (๋‚ด์šฉ์€...)

์ด๋ก ์ ์œผ๋กœ Top-p๋Š” Top-k๋ณด๋‹ค ๋” ์šฐ์•„ํ•ด๋ณด์ด์ง€๋งŒ ์‹ค์ œ๋กœ๋Š” ๋‘ ๋ฐฉ๋ฒ• ๋ชจ๋‘ ์ž˜ ๋™์ž‘ํ•˜๊ณ  Top-p๋Š” Top-k์™€ ํ•จ๊ป˜ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. Top-K๋Š” ๋งค์šฐ ๋‚ฎ์€ ์ˆœ์œ„์˜ ๋‹จ์–ด๋ฅผ ํ”ผํ•˜๋ฉด์„œ ์ผ๋ถ€ ๋™์  ์„ ํƒ์„ ํ—ˆ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋งˆ์ง€๋ง‰์œผ๋กœ ๋…๋ฆฝ์ ์œผ๋กœ ์ƒ˜ํ”Œ๋ง๋œ ์—ฌ๋Ÿฌ ์ถœ๋ ฅ์„ ์–ป๊ธฐ ์œ„ํ•ด ๋งค๊ฐœ๋ณ€์ˆ˜ num_return_sequences๋ฅผ 1๋ณด๋‹ค ํฌ๊ฒŒ ๋‹ค์‹œ ์„ค์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

sample_outputs = model.generate(
    input_ids,
    do_sample=True,
    max_length=128,
    top_p=0.95,
    top_k=50,
    num_return_sequences=3,
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
    decoded_text = tokenizer.decode(sample_output, skip_special_tokens=True)
    print(f"{i}: {decoded_text}", end="\n\n")
Output:
----------------------------------------------------------------------------------------------------
0: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋น„ํƒ€๋ฏผ C๊ฐ€ ๋งค์šฐ ์ค‘์š”ํ•œ๋ฐ, ์ด ๋น„ํƒ€๋ฏผC๊ฐ€ ๋‡Œ๋ฅผ ๊ฑด๊ฐ•ํ•˜๊ฒŒ ๋งŒ๋“œ๋Š” ๋ฐ ํ•„์š”ํ•œ ํ•ต์‹ฌ ์˜์–‘์†Œ๊ฐ€ ๋˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.
๋˜ํ•œ ์•„์—ฐ์ด ๋งŽ์ด ํ•จ์œ ๋œ ์Œ์‹์„ ์„ญ์ทจํ•  ๊ฒฝ์šฐ ์‹ฌ์žฅ๋ณ‘์ด ์œ„ํ—˜ํ•  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ, ์Œ์‹์„ ์”น๋Š” ํšŸ์ˆ˜๋ฅผ ์ค„์ด๋„๋ก ํ•˜๊ณ  ๋น„ํƒ€๋ฏผ C๊ฐ€ ๋งŽ์ด ํ•จ์œ ๋œ ์‹ํ’ˆ์„ ๊พธ์ค€ํžˆ ์„ญ์ทจํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.
๋˜ํ•œ ๋น„ํƒ€๋ฏผ C๋ฅผ ๋งŽ์ด ํ•จ์œ ํ•œ ์‹ํ’ˆ์€ ์ŠคํŠธ๋ ˆ์Šค๋ฅผ ํ•ด์†Œํ•˜๊ณ  ๋‹ค์ด์–ดํŠธ์—๋„ ๋„์›€์ด ๋˜๋ฉฐ, ๋‡Œ์กธ์ค‘์„ ์˜ˆ๋ฐฉํ•˜๋Š” ํšจ๊ณผ๋„ ๊ธฐ๋Œ€ํ•  ์ˆ˜ ์žˆ๋‹ค.
์ด๋Ÿฐ ์ด์œ ๋กœ ์ตœ๊ทผ์—๋Š” โ€˜๊ฑด๊ฐ•ํ•œ ์šฐ๋ฆฌ ๋ชธโ€™์—์„œ ํƒ„์ˆ˜ํ™”๋ฌผ๊ณผ ์ง€๋ฐฉ, ๋‹จ๋ฐฑ์งˆ๊ณผ ๋น„ํƒ€๋ฏผ์„ ์กฐํ•ฉํ•˜๋Š” ๋ฐœํšจ์‹ํ’ˆ์ด ๊ด€์‹ฌ์„ ๋ฐ›๊ณ  ์žˆ๋‹ค.
๋‹น๊ทผ๊ณผ ์ƒ๊ฐ•์€ ํ˜ˆ์ค‘ ๋น„ํƒ€๋ฏผ์„ ๋ณด์ถฉํ•ด ์‹ฌํ˜ˆ๊ด€๊ณ„์— ๋„์›€์„

1: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋จผ์ € ๋ชธ์— ๋ฌด๋ฆฌ๊ฐ€ ๊ฐ€์ง€ ์•Š๋„๋ก ํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. LG์ „์ž์˜ ์ƒˆ๋กœ์šด ๋””์ž์ธ ์ฝ˜์…‰ํŠธ์ธ '๋งค์ง์ŠคํŽ˜์ด์Šค'๋Š” ํ˜์‹ ์ ์ธ ๋””์ž์ธ์„ ๋„˜์–ด ์ œํ’ˆ ์„ฑ๋Šฅ์ด ์—…๊ทธ๋ ˆ์ด๋“œ ๋œ ๊ฒƒ์„ ์˜๋ฏธํ•œ๋‹ค.
LG์ „์ž๋Š” ์˜ฌํ•ด ์„ ๋ณด์ธ ์˜ฌ๋ ˆ๋“œ TV์˜ 'ํ”„๋ฆฌ๋ฏธ์—„' ์ฝ˜์…‰ํŠธ์ธ '๋งค์ง์ŠคํŽ˜์ด์Šค'๋ฅผ ํ†ตํ•ด LG๋งŒ์˜ '3์„ธ๋Œ€'(4K) OLED(์œ ๊ธฐ๋ฐœ๊ด‘๋‹ค์ด์˜ค๋“œ), '์Šค๋งˆํŠธ OLED(์œ ๊ธฐ๋ฐœ๊ด‘๋‹ค์ด์˜ค๋“œ)' ๋“ฑ์œผ๋กœ ํ™•๋Œ€ํ•ด ํ”„๋ฆฌ๋ฏธ์—„ ๋ผ์ธ์—…์˜ ๊ฒฝ์Ÿ๋ ฅ์„ ๊ฐ•ํ™”ํ•ด ๋‚˜๊ฐ„๋‹ค๋Š” ๋ฐฉ์นจ์ด๋‹ค.
LG์ „์ž๋Š” ํŠนํžˆ OLED์˜ ๊ฒฝ์šฐ

2: ๊ทผ์œก์ด ์ปค์ง€๊ธฐ ์œ„ํ•ด์„œ๋Š” ์šฐ์„  ํ”ผ๋ถ€ ์ž์ฒด์˜ ๊ท ํ˜•์ด ์ค‘์š”ํ•˜๋‹ค.
ํŠนํžˆ ๊ฑด์กฐํ•œ ํ”ผ๋ถ€์˜ ๊ฒฝ์šฐ ์ˆ˜๋ถ„ ๊ณต๊ธ‰์— ๋Œ€ํ•œ ์ ์ ˆํ•œ ๊ด€๋ฆฌ๋งŒ์ด ํ”ผ๋ถ€ ํ†ค์„ ๊ฑด๊ฐ•ํ•˜๊ฒŒ ๋˜์ฐพ์•„ ์ฃผ๋Š” ์ตœ์„ ์˜ ๋ฐฉ๋ฒ•์ด๋‹ค.
๋”ํŽ˜์ด์Šค์ƒต์˜ โ€˜๋”ํŽ˜์ด์Šค์ƒต ์ˆ˜๋ถ„ ํฌ๋ฆผโ€™์€ ์ˆ˜๋ถ„ํฌ๋ฆผ์œผ๋กœ ์‚ฌ์šฉ ์‹œ ์ˆ˜๋ถ„์ด ๋”์šฑ ๊ฐ•ํ•˜๊ฒŒ ํก์ˆ˜๋ผ ํ”ผ๋ถ€ ๋ณธ์—ฐ์˜ ํ”ผ๋ถ€ ๋ณด์Šต์— ๋„์›€์„ ์ค„ ์ˆ˜ ์žˆ๋‹ค.
ํŠนํžˆ ํ”ผ๋ถ€์˜ ์ˆ˜๋ถ„ํ•จ์œ ๋Ÿ‰์— ๋”ฐ๋ผ์„œ ๋‹ค์–‘ํ•œ ์ œํ’ˆ์ด ๊ฐœ๋ฐœ๋˜๋ฉฐ ์ˆ˜๋ถ„ํฌ๋ฆผ์„ ๋ฐ”๋ฅธ ํ›„ ํ”ผ๋ถ€ํ†ค๊นŒ์ง€ ํ™˜ํ•˜๊ฒŒ ๋ฐํ˜€์ฃผ๊ณ  ์žˆ๋‹ค.
๋ฐ”๋”” ์ „์šฉ ์•ฐํ”Œ ํƒ€์ž…์˜ ๊ณ ๋ณด์Šต ์ˆ˜๋ถ„ํฌ๋ฆผ์€ ์ž์™ธ์„ ์ฐจ๋‹จ ๊ธฐ๋Šฅ์ด ์žˆ์–ด ์—ฌ๋ฆ„์ฒ  ์•ผ์™ธํ™œ๋™์ด ๋งŽ์€ ํ™˜์ ˆ๊ธฐ์— ์ ํ•ฉํ•˜๋‹ค.
ํŠนํžˆ ์—ฌ๋ฆ„์ฒ  ๋ฏผ๊ฐํ”ผ๋ถ€์—๋Š” ๋ฐ”๋ฅด๋Š” ์ฆ‰์‹œ ํ”ผ๋ถ€์— ๋ณด์Šต๋ง‰์„ ํ˜•์„ฑํ•ด ์ฃผ๋Š”

Conclusion

  • top-p and top-k sampling appear to produce more fluent text than traditional greedy and beam search in open-ended language generation
  • Recently there is more evidence that the apparent flaws of greedy and beam search (mainly generating repetitive word sequences) are caused by the model (in particular, the way it is trained) rather than by the decoding methodology.
  • Also, top-k and top-p sampling appear not to be entirely free from generating repetitive word sequences either
  • According to Welleck et al. (2019), human evaluation shows that beam search can generate more fluent text than top-p sampling when the model's training objective is adapted
  • Open-ended language generation is a rapidly evolving field, and there is often no one-size-fits-all method, so you have to see which method works best for your specific use case.

Appendix

์œ„์—์„œ ์–ธ๊ธ‰ํ•˜์ง€ ์•Š์€ ์ƒ์„ฑ ๋ฉ”์†Œ๋“œ์— ๋Œ€ํ•œ ๋ช‡ ๊ฐ€์ง€ ์ถ”๊ฐ€ ๋งค๊ฐœ๋ณ€์ˆ˜ ์†Œ๊ฐœ

  • min_length: can be used to force the model not to produce an EOS token before min_length is reached
    • used quite frequently in summarization, and generally useful whenever the user wants longer outputs
  • repetition_penalty: used to penalize words that were already generated or belong to the context. First introduced by Keskar et al. (2019); also used in the training objective of Welleck et al. (2019). It can be very effective at preventing repetitions, but seems to be very sensitive to different models and use cases; see the related discussion.
  • attention_mask: used to mask padded tokens
  • pad_token_id, bos_token_id, eos_token_id: if the model does not have these tokens by default, the user can manually choose other token ids to represent them.
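As a combined sketch (the parameter values are illustrative, not recommendations; passing pad_token_id=tokenizer.eos_token_id is an assumption that makes sense for KoGPT2, whose tokenizer defines no pad token):

output = model.generate(
    input_ids,
    do_sample=True,
    max_length=128,
    min_length=64,                        # no EOS before 64 tokens
    repetition_penalty=1.2,               # penalize already-seen tokens (Keskar et al., 2019)
    top_p=0.95,
    top_k=50,
    pad_token_id=tokenizer.eos_token_id,  # fall back to EOS as the pad token
)
print(tokenizer.decode(output[0], skip_special_tokens=True))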

Dissecting GenerationMixin