본문 바로가기
AI/기술

BERT는 어떻게 학습시킬까? (Raw text 에서 Training Instance 까지)

by ai.forme 2021. 2. 6.
반응형

수많은 NLP Downstream Task에서 SOTA를 달성한 BERT에 대해 알아보자. 본 글에서는 모델의 구조와 성능에 대한 얘기가 아닌, BERT 학습의 전반적인 이야기를 해보고자 한다. 따라서 본 글은 BERT의 모델 구조에 대한 이해를 필요로 한다. 아직 BERT가 무엇인지 모른다면 아래의 여기를 참고하자. 

 

본 글은 NVIDIA/BERT 코드를 읽고 정리한 것이다.


Raw text 에서 Training Instance 까지

 

(아시다시피) BERT는 아래와 같은 구조가 하나의 학습 객체이다. 1편에서는 자연어 (인간의 언어)를 BERT가 학습할 수 있는 아래의 형태로 바꾸는 작업에 대해 알아볼 것이다.

 

 

1. Bytes > Unicode

 

Python 3 에서 기본적으로 urllib 모듈을 통해서 크롤링한 데이터는 byte 형태이다. Linux에서는 기본 encoding이 utf-8이라 특별히 문제가 되지 않지만 window에선 기본 encoding 이 cp949라서  urllib모듈을 통해 가져온 문자열이 byte형태로 변해서 육안으로 확인이 힘들어진다. 이때 byte형태의 데이터를 Decoding 해줌으로써 문자열을 육안으로 확인할 수 있게 된다. 아래는 여기를 크롤링한 결과와 이를 Decoding 한 결과이다. 본 과정을 거쳐야 하는 이유는 많은 경우 언어 모델을 사전 학습시킬 때 인터넷에서 크롤링해서 데이터를 모으기 때문이다. 한국어의 경우 정제된 데이터가 터무니없이 부족한 상황이다.

 

Decoding 전

 

Decoding 후

 

Python 에서 Decoidng 은 아래와 같이 할 수 있다. decode 라는 내장 함수가 있어 상당히 간편하다.

import six

def convert_to_unicode(text):

    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")  # Deocde
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python 3?")
        

 

 

2. Raw Data > Plain Text

 

Raw Data란 텍스트만이 아닌 다른 정보들이나 문자, 기호 등이 섞여 있는 데이터를 의미한다. 이름 그대로 날것이라고 생각하면 된다. Raw Data의 예시로는 아래와 같은 것들 있다.

 

1. 국립국어원 모두의 말뭉치 신문 말뭉치
2. 나무 위키
3. Kowiki

 

형태는 Json, Zip 등으로 다양하다. 날것의 정보들이 그렇듯, 대부분 크기가 매우 커 압축되어 있다. 모델 학습을 위한 데이터 전처리의 첫 번째 단계는 날것의 데이터에서 학습시킬 데이터만을 추출하는 것이다. 이 과정을 파싱 (Parsing) 이라고 하며, 본 글에서는 순수한 텍스트가 되겠다. 유의할 것은 이 과정은 학습의 목적이나 방법에 따라 굉장히 다양하다는 것이다. 똑같은 BERT를 학습할 때도 다른 방법을 취할 수 있으며, 본 글은 그저 필자가 선택한 방법이다. 중요한 것은 방법은 달라도 반드시 이 과정을 거쳐야  한다는 것이다.

 

파싱의 과정은 크게 아래의 세 가지 단계로 구분된다.

 

- 데이터 정제하기

- 필요 없는 부분 삭제하기

- 학습하는 과정에서 필요한 정보 추가하기

 

수학</a></b>에서 <b>상수</b>란 그 값이 변하지 않는 불변량으로, <a href="/wiki/%EB%B3%80%EC%88%98_(%EC%88%98%ED%95%99)" title="변수 (수학)">변수</a>의 반대말이다. <a href="/wiki/%EB%AC%BC%EB%A6%AC_%EC%83%81%EC%88%98" title="물리 상수">물리 상수</a>와는 달리, 수학 상수는 물리적 측정과는 상관없이 정의된다.
</p><p>수학 상수는 대개 <a href="/wiki/%EC%8B%A4%EC%88%98%EC%B2%B4" class="mw-redirect" title="실수체">실수체</a>나 <a href="/wiki/%EB%B3%B5%EC%86%8C%EC%88%98%EC%B2%B4" class="mw-redirect" title="복소수체">복소수체</a>의 원소이다. 우리가 이야기할 수 있는 상수는 (거의 대부분 <a href="/wiki/%EA%B3%84%EC%82%B0_%EA%B0%80%EB%8A%A5%ED%95%9C_%EC%88%98" title="계산 가능한 수">계산 가능</a>한) <a href="/wiki/%EC%A0%95%EC%9D%98%EA%B0%80%EB%8A%A5%ED%95%9C_%EC%88%98" class="mw-redirect" title="정의가능한 수">정의가능한 수</a>이다.
</p><p>특정 수학 상수, 예를 들면 <a href="/wiki/%EA%B3%A8%EB%A1%AC-%EB%94%95%EB%A7%A8_%EC%83%81%EC%88%98" title="골롬-딕맨 상수">골롬-딕맨 상수</a>, <a href="/wiki/%ED%94%84%EB%9E%91%EC%84%B8%EC%A6%88-%EB%A1%9C%EB%B9%88%EC%8A%A8_%EC%83%81%EC%88%98" title="프랑세즈-로빈슨 상수">프랑세즈-로빈슨 상수</a>, <a href="/wiki/%EC%A0%9C%EA%B3%B1%EA%B7%BC_2" title="제곱근 2"><span class="mwe-math-element"><span class="mwe-math-mathml-inline mwe-math-mathml-a11y" style="display: none;"><math xmlns="http://www.w3.org/1998/Math/MathML"  alttext="{\displaystyle {\sqrt {2}}}">
  <semantics>
    <mrow class="MJX-TeXAtom-ORD">
      <mstyle displaystyle="true" scriptlevel="0">
        <mrow class="MJX-TeXAtom-ORD">
          <msqrt>
            <mn>2</mn>
          </msqrt>
        </mrow>
      </mstyle>
    </mrow>
    <annotation encoding="application/x-tex">{\displaystyle {\sqrt {2}}}</annotation>
  </semantics>
</math></span><img src="https://wikimedia.org/api/rest_v1/media/math/render/svg/b4afc1e27d418021bf10898eb44a7f5f315735ff" class="mwe-math-fallback-image-inline" aria-hidden="true" style="vertical-align: -0.671ex; width:3.098ex; height:3.009ex;" alt="{\sqrt {2}}"/></span></a>, <a href="/wiki/%EB%A0%88%EB%B9%84_%EC%83%81%EC%88%98" title="레비 상수">레비 상수</a>와 같은 상수는 다른 수학상수 또는 함수와 약한 상관관계 또는 강한 상관관계를 갖는다.
</p>
<div id="toc" class="toc" role="navigation" aria-labelledby="mw-toc-heading"><input type="checkbox" role="button" id="toctogglecheckbox" class="toctogglecheckbox" style="display:none" /><div class="toctitle" lang="ko" dir="ltr"><h2 id="mw-toc-heading">목차</h2><span class="toctogglespan"><label class="toctogglelabel" for="toctogglecheckbox"></label></span></div>
<ul>
<li class="toclevel-1 tocsection-1"><a href="#수학_상수표"><span class="tocnumber">1</span> <span class="toctext">수학 상수표</span></a></li>
<li class="toclevel-1 tocsection-2"><a href="#관련_상수들"><span class="tocnumber">2</span> <span class="toctext">관련 상수들</span></a>
<ul>
<li class="toclevel-2 tocsection-3"><a href="#e_관련"><span class="tocnumber">2.1</span> <span class="toctext">e 관련</span></a></li>
<li class="toclevel-2 tocsection-4"><a href="#감마함수_관련"><span class="tocnumber">2.2</span> <span class="toctext">감마함수 관련</span></a></li>
<li class="toclevel-2 tocsection-5"><a href="#소수_관련"><span class="tocnumber">2.3</span> <span class="toctext">소수 관련</span></a></li>
<li class="toclevel-2 tocsection-6"><a href="#리만제타함수_관련"><span class="tocnumber">2.4</span> <span class="toctext">리만제타함수 관련</span></a></li>
<li class="toclevel-2 tocsection-7"><a href="#피보나치_수_관련"><span class="tocnumber">2.5</span> <span class="toctext">피보나치 수 관련</span></a></li>
<li class="toclevel-2 tocsection-8"><a href="#특정_대수방정식_관련"><span class="tocnumber">2.6</span> <span class="toctext">특정 대수방정식 관련</span></a></li>
</ul>
</li>
<li class="toclevel-1 tocsection-9"><a href="#기타_상수들"><span class="tocnumber">3</span> <span class="toctext">기타 상수들</span></a></li>
<li class="toclevel-1 tocsection-10"><a href="#같이_보기"><span class="tocnumber">4</span> <span class="toctext">같이 보기</span></a></li>
</ul>
</div>

 

세 단계를 설명하기 위하여 위와 같은 예시를 참고해보자. 아래의 예시는 날것 (Raw Data) 상태이다. Kowiki-수학 상수에서 윗부분을 가져온 것인데, 처음 가져오게 되면 데이터가 굉장히 많은 HTML 태그와 함께 아주 더러운 것을 볼 수 있다. 따라서 해당 데이터에서 자연어만 추출해야 된다.

 

 

 (1) 데이터 정제하기

 

위 데이터에는 많은 HTML 태그들이 포함되어 있다. 이런 기호들은 텍스트에 대한 정보를 담고 있다. 하지만 언어 모델을 학습시킬 때는 해당 기호들이 필요 없다. (필요한 경우도 있겠지만) 따라서 저러한 특수한 기호들을 제거하는 것을 데이터 정제라고 한다. 특수한 기호들을 없애고 나면 아래와 같은 결과를 얻을 수 있다.

 

# Article: 수학 상수
# Type: regular article

수학에서 상수란 그 값이 변하지 않는 불변량으로, 변수의 반대말이다. 물리 상수와는 달리, 수학 상수는 물리적 측정과는 상관없이 정의된다.

수학 상수는 대개 실수체나 복소수체의 원소이다. 우리가 이야기할 수 있는 상수는 (거의 대부분 계산 가능한) 정의가능한 수이다.

특정 수학 상수, 예를 들면 골롬-딕맨 상수, 프랑세즈-로빈슨 상수, {\displaystyle {\sqrt {2}}}{\sqrt {2}}, 레비 상수와 같은 상수는 다른 수학상수 또는 함수와 약한 상관관계 또는 강한 상관관계를 갖는다.


수학 상수표

관련 상수들

  e 관련

  감마함수 관련

  소수 관련

  리만제타함수 관련

  피보나치 수 관련

  특정 대수방정식 관련

기타 상수들

같이 보기

 

 (2) 데이터 변형하기

 

위 (1)의 결과의 위쪽을 보면  # Article, # Type 과 같은 메타 데이터가 포함되었다. (메타 데이터란 해당 데이터에 대한 정보를 말한다) 필자는 위와 같이 메타 데이터를 그대로 학습하는 것은 일반적이지 않다고 판단하여, 이를 변형하고자 한다. (메타 데이터가 모든 글에 있는 것은 아니기 때문이다) 따라서 좀 더 일반적인 형태인, 제목이 글 위에 있는 형식으로 바꾼다. 그러면 아래와 같은 결과를 얻을 수 2. Raw Data > Plain Text 있겠다.

 

수학상수

수학에서 상수란 그 값이 변하지 않는 불변량으로, 변수의 반대말이다. 물리 상수와는 달리, 수학 상수는 물리적 측정과는 상관없이 정의된다.

수학 상수는 대개 실수체나 복소수체의 원소이다. 우리가 이야기할 수 있는 상수는 (거의 대부분 계산 가능한) 정의가능한 수이다.

특정 수학 상수, 예를 들면 골롬-딕맨 상수, 프랑세즈-로빈슨 상수, {\displaystyle {\sqrt {2}}}{\sqrt {2}}, 레비 상수와 같은 상수는 다른 수학상수 또는 함수와 약한 상관관계 또는 강한 상관관계를 갖는다.


수학 상수표

관련 상수들

  e 관련

  감마함수 관련

  소수 관련

  리만제타함수 관련

  피보나치 수 관련

  특정 대수방정식 관련

기타 상수들

같이 보기

 

 (3) 학습하는 과정에서 필요한 정보 추가하기

 

BERT 의 경우 학습할 때 문서 (Document) 와 문장 (Sentence) 가 정확히 구분되어야 한다. 따라서 문서와 문장을 구분하는 기호를 추가해야 한다. 실제로 학습할 때는 위 기호들은 제거된다. 필자는 추가로 챕터 (Chapter) 를 구분하는 기호를 넣어 필요할 경우 사용하려고 한다. 본 과정을 거치면 아래와 같은 최종 Plain Text 를 얻을 수 있겠다.

 

[DOC_SEP]

[CH_SEP]
수학에서 상수란 그 값이 변하지 않는 불변량으로, 변수의 반대말이다.[SEN_SEP]물리 상수와는 달리, 수학 상수는 물리적 측정과는 상관없이 정의된다.[SEN_SEP]

수학 상수는 대개 실수체나 복소수체의 원소이다.[SEN_SEP]우리가 이야기할 수 있는 상수는 (거의 대부분 계산 가능한) 정의가능한 수이다.[SEN_SEP]

특정 수학 상수, 예를 들면 골롬-딕맨 상수, 프랑세즈-로빈슨 상수, {\displaystyle {\sqrt {2}}}{\sqrt {2}}, 레비 상수와 같은 상수는 다른 수학상수 또는 함수와 약한 상관관계 또는 강한 상관관계를 갖는다.[SEN_SEP]

[CH_SEP]
수학 상수표[SEN_SEP]

관련 상수들[SEN_SEP]

  e 관련[SEN_SEP]

  감마함수 관련[SEN_SEP]

  소수 관련[SEN_SEP]

  리만제타함수 관련[SEN_SEP]

  피보나치 수 관련[SEN_SEP]

  특정 대수방정식 관련[SEN_SEP]

기타 상수들[SEN_SEP]

같이 보기[SEN_SEP]

 

 

3. Plain Text > Shard

 

Sharding 이란 여러 Plain Text 를 여러 파일들로 나누는 작업을 의미한다. Sharding 의 목적은 크게 아래와 같다.

 

- 멀티 프로세스로 처리하기 용이

- 데이터를 섞는 효과

 

여기서 데이터를 섞는다는 것은, 여러 출처의 자료들은 섞어서 하나의 Shard (작은 Plain Text) 를 만드는 것을 의미한다. 이렇게 여러 출처의 자료를 섞어야 하는 이유는 언어 모델을 사전 학습시킬 때는 General 함이 유지되어야 하기 때문이다. 만약 데이터를 출처별로 순차적으로 학습한다면 그 데이터에 모델이 과적합 될 수 있기 때문이다. 

 

 

아래의 그림처럼 kowiki.txt / namuwiki.txt / newspaper.txt 가 있을 때, 세 가지 출처의 데이터를 조금씩 섞어서 하나의 ****.txt 를 만들었다.

 

Sharding 전
Sharding 후

 

 

4. Plain Text > Tokens

 

4.1.  Basic Tokenizer

텍스트를 본격 토큰화 하기 전에 아래와 같은 Cleaning 하는 과정을 거친다.

 

 

- 문단 단위 (위의 예시에서 [DOC_SEP]) 로 텍스들을 나눈다.

 

- 유효하지 않은 Character 들을 제거하고 Whitespace를 모두 띄어쓰기로 바꾼다.

import unicodedata
    
    
def clean_text(text):
    """Performs invalid character removal and whitespace cleanup on text."""
    output = []
    for char in text:
        cp = ord(char)
        if cp == 0 or cp == 0xfffd or is_control(char):
            continue
        if is_whitespace(char):
            output.append(" ")
        else:
            output.append(char)
    return "".join(output)
        
 
 def is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `chars` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False

 

 

- 특수한 언어들을 처리한다.

예시로 중국어의 경우 한자 하나가 한 단어를 뜻하지만 띄어쓰기가 되어있지 않다. 따라서 한자의 양 옆에 빈칸을 넣어준다.

def tokenize_chinese_chars(text):
    """Adds whitespace around any CJK character."""
    
    output = []
    for char in text:
        cp = ord(char)
        if _is_chinese_char(cp):
            output.append(" ")
            output.append(char)
            output.append(" ")
        else:
            output.append(char)
            
    return "".join(output)
    
    
def is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and handled
    # like the all of the other languages.
    
    if ((0x4E00 <= cp <= 0x9FFF) or
            (0x3400 <= cp <= 0x4DBF) or
            (0x20000 <= cp <= 0x2A6DF) or
            (0x2A700 <= cp <= 0x2B73F) or
            (0x2B740 <= cp <= 0x2B81F) or
            (0x2B820 <= cp <= 0x2CEAF) or
            (0xF900 <= cp <= 0xFAFF) or
            (0x2F800 <= cp <= 0x2FA1F)):
        return True
        
    return False

 

- 문장 부호를 고려하여 문장을 단어로 분리한다.

import unicodedata


def _run_split_on_punc(token):
    """Splits punctuation on a piece of text."""
    if token in self.never_split:  # Don't split
        return [token]
    chars = list(token)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
        char = chars[i]
        if is_punctuation(char):
            output.append([char])
            start_new_word = True
        else:
            if start_new_word:
                output.append([])
            start_new_word = False
            output[-1].append(char)
        i += 1
    return ["".join(x) for x in output]
       
       
def is_punctuation(char):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


 

 

4.2. Subword Tokenizer (Wordpiece)

 

위 Basic Tokenizer 을 거친 후 Subword Token으로 바꾸는 과정을 거치는데, 이를 Subword Tokenize 라고 한다.  아래의 예시를 참고해보자. Tokenize 에 대해 이해하려면 우선 Vocab이라는 것을 알아야 한다. 언어 모델은 Vocab이라고 하는 Token의 리스트를 가진다. 이는 아래와 같이 생겼다.

 

이용해서
하이라
빛이
즐겁게
변환
세계로
발병
예정이며
비상대책
모두에게
##공간을
##EF
초반에는
것일까
가격으로
##오면
159
가정의
새겨
라스
있듯이
##자면
##って
##전당
돕고
적용되는
우익
##브로
덤으로
주체
##혹한
##시에서
적힌
대학이
정보기술

 

각각은 모두 토큰들이며, 언어 모델은 자신의 Vocab 외에 있는 토큰들은 모두 [UNK] 토큰으로 처리한다. 자신이 모델을 학습시킬 데이터들로 모델의 Vocab을 만들어야 하는데, BERT를 위한 Vocab는 아래와 같이 Hugging Face의 BERT Tokenizer을 사용하여 만들 수 있다. 참고로 BERT의  Vocab 수는 32000개였고, GPT-3의 Vocab 수는 50257 개다.

import argparse
import json
from tokenizers import BertWordPieceTokenizer
from glob import glob

parser = argparse.ArgumentParser()

parser.add_argument('--vocab_size', type=int, default=32000)
parser.add_argument('--limit_alphabet', type=int, default=6000)

args = parser.parse_args()

tokenizer = BertWordPieceTokenizer(
    vocab=None,
    clean_text=True,
    handle_chinese_chars=True,
    strip_accents=False,
    lowercase=False,
    wordpieces_prefix="##"
)

data = glob('/data/nlp/Merged/**training**.txt')  # 텍스트 데이터들

tokenizer.train(
    files=data,
    limit_alphabet=args.limit_alphabet,
    vocab_size=args.vocab_size
)

tokenizer.save("./ch-{}-wpm-{}-pretty".format(args.limit_alphabet, args.vocab_size), True)

vocab_path = "ch-6000-wpm-32000-pretty"
vocab_file = 'vocab.txt'

f = open(vocab_file, 'w', encoding='utf-8')
with open(vocab_path) as json_file:

    json_data = json.load(json_file)
    for item in json_data["model"]["vocab"].keys():
        f.write(item+'\n')

    f.close()

 

그러면 왜 언어 모델들은 Vocab을 가지는 것이고, 텍스트를 Tokenize 하는 것일까? 

 

텍스트가 모델 학습에 사용되기 위해선 Encoding 과 Embedding의 과정을 거쳐야 한다. Encoding은 각 토큰을 독립적인 값으로 바꾸는 것을 의미한다. (각각에 이름을 부여) Embedding에 대한 내용은 여기를 참고하자. BERT 에서의 Embedding은 (2) 편에서 다루겠다. 예를 들어, 언어 모델의 Vocab의 아래와 같이 생겼다고 해보자. 

[UNK]
아버지
##가
방
##에
들
##어
##가
##셨
##다

그러면 아버지가 방에 후다닥 들어가셨다 라는 문장은 아래와 같이 바뀔 수 있다.

아버지가 방에 후다닥 들어가셨다 

=> 아버지 ##가 방 ##에 [UNK] 들 ##어 ##가 ##셨 ##다 (Tokenize)

=>   1    2  3   4    0    5   6   7    8   9 (Encoding)

 

언어 모델이 Vocab을 가지고 텍스트를 Tokenize 하는 이유는 인간의 모든 단어를 Encoding 하는 것이 불가능하기 때문이다. 세상에 존재하는 모든 단어들의 숫자는 무한대로 발산하기 때문에 모든 각 단어들에 다른 Index를 부여할 수 없다. 따라서 모델이 처리할 수 있는 (아는) 단어를 정해주는데, 이를 Vocab이라고 하는 것이다.

 

그러면 모델이 처리할 수 있는 단어들만 따로 저장해두면 되지, 왜 단어들을 분리하는 것일까? Tokenize는 크게 Word Tokenizer 과 Subword Tokenizer로 구분된다. Word Tokenizer는 한 단어를 쪼개지 않고 토큰화 하는 것을 의미하며 Subword Tokenizer는 한 단어를 여러 개의 토큰으로 바꾸는 것을 의미한다. 즉 앞의 질문을 다른 말로 표현하면, Word Tokenizer만 쓰면 되지, 왜  Subword Tokenizer 을 사용하는 것일까?

 

예를 들어, Word Tokenizer는 아래와 같이 각 단어를 하나의 토큰으로 취급한다. 

경찰관 => 121
경찰복 => 232
경찰차 => 634
경찰소 => 32

 

반면, Subword Tokenizer는 한 단어를 유의미한 토큰으로 분리한다.

경찰관 => 경찰 ##관 => 12 561
경찰복 => 경찰 ##복 => 12 321
경찰차 => 경찰 ##차 => 12 1231
경찰소 => 경찰 ##소 => 12 3123

 

최근에는 대부분 Subword Tokenizer 가 사용되는데 그 이유는 아래와 같다.

 

1. 경우의 수를 줄여준다 (토큰의 재사용이 가능하다)

 

Word Tokenizer의 경우 아버지는, 아버지를, 아버지가, 아버지의, 아버지만 을 모두 다른 토큰으로 처리한다. 이는 대단한 낭비가 아닐 수 없다. 반면, Subword Tokenizer는 아버지를 하나의 토큰으로 처리하고, 뒤에 붙은 조사를 다른 토큰으로 처리한다. 따라서 토큰의 재사용이 가능하여 Vocab을 낭비하지 않을 수 있다.

 

2. OOV (Out of Vocabulary) 처리에 용이

 

OOV란 모델의 Vocab에 없는 단어를 의미한다. 예를 들어, 위 상황에서 경찰력이라는 새로운 단어가 나왔다면, Word Tokenizer의 경우 이를 [UNK] 토큰으로 처리하는 반면, Subword Tokenizer는 경찰 ##력 혹은 경찰 [UNK] 으로 처리할 것이다. 따라서 Subword Tokenizer 는  OOV를 마주했을 때 그 내재된 의미를 추출할 수 있다는 장점을 가지고 있다.

 

 

Subword Tokenizer 의 종류로 BPE (Byte Pair Encoding), Wordpiece, Unigram Language Model Tokenizer, Sentencepiece가 있는데 각각에 대한 설명은 여기를 참고하자.

 

BERT에서는 Wordpiece로 Vocab을 만든 뒤 Greedy Longest-Match-First Algorithm 방식을 사용하여 텍스트들을 Token 으로 바꾼다. Greedy Longest-Match-First Algorithm 은 아래의 예시를 보면 쉽게 이해할 수 있다. (이름만 거창)

 

[설명]

각 단어를 토큰화 할 때 Start (S) 와 End (E) 가 있다.
S, E는 각 단어의 Character Index를 의미한다.

단어 (S, E) : 타겟 -> 결과
타겟은 단어[S:E] 이다.
결과는 해당 토큰이 Vocab에 있는지 없는지이다. (있으면 O / 없으면 X)

결과가 X 이면 E를 1을 줄인다.
결과가 O 이면 당시 E 값을 S에 주고 다시 끝 Index부터 시작한다.
결과가 O 일 때 S가 0이 아니면 앞에 ##을 붙인다.
예시로 unaffable 이라는 단어를 토큰화 해보자.


[예시]

unaffable (0, 9) : unaffable → X

unaffable (0, 8) : unaffabl → X

...


un (0, 2) : un → O  => un

unaffable (2, 9) : affable → X

unaffable (2, 8) : affabl → X

...


aff (2, 5) : aff → O => ##aff

able (5, 9) : able → O => ##able


[결과]

unaffable => un ##aff ##able

 

 

5. Tokens > [CLS] Seq_A [SEP] Seq_B [SEP]

 

 

본 글의 초반에 말했지만, BERT는 위와 같은 Input을 가지며 BERT의 사전학습은 MASKED 된 토큰을 맞추고 다음 A와 B가 이어지는 Sequence인지 맞추는 과정이다. 따라서 토큰들을 모델이 학습할 수 있는 정보들을 담은 하나의 객체로 저장해주어야 한다. 이때 저장하는 확장자로는 npy, hdf5, lmdb 등을 다양하게 사용한다. 모두 크기가 큰 데이터들을 저장하는데 용이한 확장자들이다.

 

 

5.1. Target Sequence Length

BERT의 Input 형태는 [CLS] Seq_A [SEP] Seq_B [SEP] 꼴이다. Seq_A 와 Seq_B 는 일련의 토큰들의 집합이다. 따라서 A 와 B에 각각 몇 개의 토큰이 들어갈지 결정해주어야 한다. Input의 Max Sequence Length가 128이라면, [CLS] 1개와 [SEP] 2개를 제외하면 총 125개의 토큰을 채워야 한다. 이때 125를 Target Sequence Length라고 한다.

 

다만 BERT에서는 학습의 다양성을 위해 0.1의 확률로 Target Sequence Length를 2 이상 127 이하의 정수로 선택하게 된다.

import random

rng = random.Random(args.random_seed)  # Give random a seed (To maintain the result)

target_seq_length = max_num_tokens - 3  # 125 = 128 - 3

if rng.random() < 0.1: 
    target_seq_length = rng.randint(2, max_num_tokens)  # 0.1의 확률로 125보다 작은 것

 

 

5.2. 토큰들의 개수가 Target Sequence Length를 넘을 때까지 문장들을 더함

한 문서가 총 26개의 문장 (S1, S2, .. S26)으로 이루어졌으며, 각 문장은 10개의 토큰들로 이루어진 상황을 가정해보자.

그러면 S1부터 각 토큰의 합이 Target Sequence Length 를 넘을 때까지 Chunk에 이어서 넣는다. Chunk는 그냥 Input의 후보라고 생각하면 된다. 그러면 현재 Chunk에는 S13가지 들어가 있을 것이다. 

current_chunk = []
current_length = 0
i = 0

while i < len(document):  # 한 문서에 있는 문장의 개수
    segment = document[i]  # 한 문장
    current_chunk.append(segment)  # 한 문장 append
    current_length += len(segment)  # 한 문장에 있는 토큰의 개수를 더함

    # 마지막 문장이거나 거나 / 현재 청크의 토큰 개수가 이미 target 이상이면 (보통 125)
    if i == len(document) - 1 or current_length >= target_seq_length:
    	break
    else:
    	i += 1

 

5.3. Chunk의 문장들 중 분기점을 정함

 

Chunk 에 있는 문장들 중 랜덤으로 분기점을 정한다. 위의 예시에서는 1~13 사이의 숫자가 분기점이 될 것이다. 만약 분기점이 6이었다면, S1 ~ S5 가 Seq_A 로 들어가게 된다.

# `a_end` is how many segments from `current_chunk` go into the `A`
if len(current_chunk) >= 2:  # number of sentences added to chunk
	a_end = rng.randint(1, len(current_chunk) - 1)  # 분기점을 정함
else:
	a_end = 1

seq_a = []
for j in range(a_end):  # 분리점으로 이전 segment 들은 Token_A 에 포함
	seq_a.extend(current_chunk[j])

 

 

5.4. Random Next 유무

 

(1) Random Next가 아닌 경우

 

Seq_A에 들어가지 않은 Chunk의 나머지 문장들을 Seq_B에 넣는다. 위의 예시에서는 S6 ~ S13이  Seq_B에 들어갈 것이다.

 

is_random_next = False
for j in range(a_end, len(current_chunk)):
	seq_b.extend(current_chunk[j])

 

 

(2) Random Next인 경우

 

BERT 에서는 0.5의 확률로 Seq_A 다음에 랜덤 토큰들이 Seq_B로 오게 된다. 이때 Seq_B를 정하는 방법은 완전 다른 문서에서 일련의 토큰들을 가져오는 것이다.

 

if len(current_chunk) == 1 or rng.random() < 0.5:  # 청크에 문장이 1개이거나 50%의 확률로 다른 문장 이어붙임
    
    is_random_next = True
    target_b_length = target_seq_length - len(seq_a)  # Seq_B 에 들어갈 토큰의 개수를 구한다

    for _ in range(10):  # B에 들어갈 토큰들을 가져올 문서를 정함
        random_document_index = rng.randint(0, len(all_documents) - 1)  # 무작위 문서
        if random_document_index != document_index:  # A를 가져온 문서와 다른 문서여야 한다
            break 

    # If picked random document is the same as the current document
    if random_document_index == document_index:
        is_random_next = False  # 혹여나 같으면 random_next를 False로 만든다

    random_document = all_documents[random_document_index]
    random_start = rng.randint(0, len(random_document) - 1)  # 고른 문서에서 랜덤 시작점 고름

    for j in range(random_start, len(random_document)):
        tokens_b.extend(random_document[j])  # 토큰화된 문장 하나씩 넣는다
        if len(seq_b) >= seq_b_length:  # Seq_B
            break

    # 분기점 이후로 문장들은 실제로 안썼다 (S6~S13) -> 다시 보관한다
    num_unused_segments = len(current_chunk) - a_end
    i -= num_unused_segments
    

 

 

5.5. Truncate

Seq_A + Seq_B 가 Target Sequence Length 보다 길면 Truncate 하는 과정을 거친다.

Truncate 할 때는 A와 B 중 더 긴 것의 앞이나 뒤에서 토큰을 하나씩 제거해준다.

def truncate_seq_pair(seq_a, seq_b, target_seq_length, rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    
    while True:
        total_length = len(seq_a) + len(seq_b)  
        if total_length <= target_seq_length:  # 합이 target_seq_length 보다 클 때
            break

        trunc_tokens = seq_a if len(seq_a) > len(seq_b) else seq_b  # 더 긴 것 선택
        assert len(trunc_tokens) >= 1

        # We want to sometimes truncate from the front and sometimes from the
        # back to add more randomness and avoid biases.
        if rng.random() < 0.5:  # 0.5의 확률로 
            del trunc_tokens[0]  # 맨 앞 토큰 제거 
        else:
        	trunc_tokens.pop()  # 맨 뒤 토큰 제거

 

위의 일련의 과정들을 모두 거치게 되면 아래와 같은 결과를 얻을 수 있다.

[CLS] Seq_A [SEP] Seq_B [SEP]

  0   0...0   0   1...1   1

여기서 0과 1은 해당 토큰이 Seq_A의 토큰인지, Seq_B의 토큰인지 나타내 주는 지표이다.

 

 

6. [CLS] Seq_A [SEP] Seq_B [SEP] > Masked Sequence

 

이제 남은 것은 일부 토큰들을 Masking 하는 것이다. 전체 코드는 아래와 같다.

def create_masked_lm_predictions(tokens, masked_lm_prob,  max_predictions_per_seq, vocab_words, rng):
    """Creates the predictions for the masked LM objective."""

    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":  # [CLS], [SEP]는 마스킹 후보에서 제외
            continue
        cand_indexes.append(i)

    rng.shuffle(cand_indexes)

    output_tokens = list(tokens)

    # 몇개를 마스킹 할 것이냐 -> 토큰 개수 * 0.15 (최대 20개)
    num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))

    masked_lms = []  # 마스킹 된 토큰 객체 (인데스와 라벨이 포함되어있다)
    covered_indexes = set()  # 마스킹할 토큰의 인덱스
    for index in cand_indexes:
        if len(masked_lms) >= num_to_predict:  # 목표치 채웠으면 끝내고
            break
        if index in covered_indexes:  # 이미 마스킹된 index 면 건너뛰기
            continue
        covered_indexes.add(index)  # 해당 Index 마스킹 할 것이다.

        # 80% of the time, replace with [MASK]
        if rng.random() < 0.8:
            masked_token = "[MASK]"  # 0.8의 확률로 마스킹
        else:
            if rng.random() < 0.5:
                masked_token = tokens[index]  # 0.1의 확률로 그대로
            else:
                masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]  # 0.1의 확률로 랜덤 토큰

        output_tokens[index] = masked_token  # 마스킹된 토큰으로 교체
        masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))  # 마스킹한 토큰의 인덱스와 라벨 저장

    masked_lms = sorted(masked_lms, key=lambda x: x.index)  # Index 로 정렬

    masked_lm_positions = []
    masked_lm_labels = []

    for p in masked_lms:
        masked_lm_positions.append(p.index)  # 마스킹한 토큰의 인덱스 저장
        masked_lm_labels.append(p.label)  # 마스킹한 토큰의 라 저장

    # (전체 토큰들 - 마스킹 반영, 마스킹 된 토큰의 위치, 마스킹된 토큰의 라)
    return output_tokens, masked_lm_positions, masked_lm_labels

 

Seq_A와  Seq_B의 토큰들 중 몇 개를 마스킹할지 우선 정한다. 그 후 랜덤으로 해당 개수만큼 토큰을 골라 마스킹하고 정답과 그 토큰의 위치를 저장한다. 특이점은 마스킹할 토큰 중 0.8의 확률로 [MASK], 0.1의 확률로 그대로, 0.1의 확률로 다른 토큰 (랜덤)으로 바꾼다는 것이다.

 

아래의 예시를 참고하자.

 

<Original Sequence>

[CLS] 아버지 ##가 방 ##에 [UNK] 들 ##어 기 ##셨 ##다 [SEP] 코딩 ##은 재미 ##있 ##다 [SEP]


<Masked Sequence>

[CLS] [MASK] ##가 방 ##에 [UNK] 들 ##어 [Mask] ##셨 ##다 [SEP] 코딩 ##은 재미 ##있 ##다 [SEP]


<Masked Token Index>

1, 8
 
 
<Masked Token Label>

아버지, 가

 

 

7. Masked Sequence > Encoding, Padding

 

모든 토큰들 ( [CLS] Seq_A [SEP] Seq_B [CLS] / 마스킹된 토큰의 라벨들 )을 Vocab에서의 Index로 바꾼다.

아래 과정을 거친다고 생각하면 된다.

 

<Masked SequencE>

[CLS] [MASK] ##가 방 ##에 [UNK] 들 ##어 [Mask] ##셨 ##다 [SEP] 코딩 ##은 재미 ##있 ##다 [SEP]


<Encoded Sequence>

 2      4    523  8 312   1   53 5234   4   323  123   3    21  33  22   21  123   3
 
 
<Original Label>
 
1, 8
 
 
<Encoded Label>

3444, 553

 

또한 전체 Sequence의 길이가 128 보다 작으면 남은 공간들을 [PAD] 토큰으로 채운다.

 

while len(masked_lm_positions) < max_predictions_per_seq:  # Fill empty space in tokens [PAD]
	
    masked_lm_positions.append(0)
	masked_lm_ids.append(0)
  

 

 

8. Encoded & Padded Sequence > Training Instance

 

Training Instance에 들어가는 정보는 아래와 같다.

 

  • 마스킹 된 토큰들의 id
  • 마스킹 된 토큰의 위치
  • 마스킹 된 코튼의 라벨
  • Random Next (True / False)
  • 0 ... 0 , 1 ... 1 (Seq_A 와 Seq_B 구분)

예시를 참고하자

 

<입력 Sequence - 인코딩 됨>

 2      4    523  8 312   1   53 5234   4   323  123   3    21  33  22   21  123   3


<입력 Mask> (Padding을 마스킹하기 위해서)

1 ... 1 0 ... 0 (뒤에 0들은 padding 된 index들)


<마스킹 된 토큰의 라벨 - 인코딩 됨>

3444, 553


<마스킹 된 토큰의 위치>

1, 8


<Random Nex>

1 (True)


<Sequence ID> (해당 토큰이 Seq_A에 속하는지 Seq_B에 속하는지)

0 ... 0 1 ... 1

 

위의 정보들을 한 객체에 저장된다. 코드로는 아래와 같다.

features["input_ids"][inst_index] = input_ids  # 토큰 -> id
features["input_mask"][inst_index] = input_mask  # 패딩 Masking
features["segment_ids"][inst_index] = segment_ids  # 해당 토큰이 A인지 B인지
features["masked_lm_positions"][inst_index] = masked_lm_positions  # 어떤 토큰이 마스킹 됐는지 (INDEX)
features["masked_lm_ids"][inst_index] = masked_lm_ids   # 마스킹 된 토큰의 라벨 -> id
features["next_sentence_labels"][inst_index] = next_sentence_label  # 랜덤 문장 : 1 / 랜덤 문장 X : 0

 

 

정리하며

생각보다 아는 지식을 글로 쉽게 전달하기가 힘들다. 지식이 덜 계층화되어서 그런가? 아무튼 위의 길고 길었던 일련의 과정들을 거쳐 드디어 Raw Data가 BERT에 학습시킬 수 있는 형태가 되었다. 2편에서는 이를 사용하여 어떻게 BERT를 학습시키는지 알아보자. 

반응형