본문 바로가기

논문/Multimodal

LayoutLM: Pre-training of Text and Layout for Document Image Understanding

Abstract

Pre-training 기법은 다양한 NLP 작업에서 최근 몇 년 간 효과적이라고 검증되었다. 그러나 대부분의 NLP application은 text-level에 집중되어있고 문서의 layout이나 style은 무시되었다. 본 논문에서는 LayoutLM을 제안하며, text와 문서 image로부터 scan된 layout 정보의 상호작용을 효과적으로 연결한다. 또한 image feature를 leverage하여 단어의 visual 정보를 모델에 잘 녹였다. 이는 최초로 text와 layout 정보를 single framework에서 document level pre-training을 jointly하게 수행한 몇몇 downstream task에서 비약적인 성능 향상을 보였다.

1. Introduction

Document AI 또는 Document Intelligence는 상대적으로 새로운 연구 분야로 대표적으로 Figure 1. 처럼 비지니스 문서(purchase orders, financial report, business emails 등등)를 자동으로 읽고 이해하는 기술을 말한다.

Figure 1: Scanned images of business documents with different layouts and formats

비즈니스 문서의 정확한 형식은 다를 수 있지만 정보는 일반 텍스트, 다중 열 layout 및 다양한 table/form/figure에서 다양한 방식으로 제공된다. 비즈니스 문서를 이해하는 것은 layout과 형식의 다양성, 스캔한 문서 이미지의 품질 저하, 템플릿 구조의 복잡성으로 인해 매우 어려운 작업이다.

과거에는 전체적인 구조보단 문서의 특정 부분 예를들어 tabular areas에 집중했다.

  • 7: CNN 기반의 pdf문서의 table detection
  • 21, 24, 29: Faster R-CNN 또는 Mask R-CNN을 사용하여 문서의 layout 분석의 정확도를 향상
  • 28 : pre-trained NLP 모델의 text 임베딩을 활용하여 문서 이미지에서 semantic structures를 추출하기 위한 end-to-end, multimodal, fully convolutional network
  • 15 : GCN 기반의 모델로 textual 과 visual 정보를 추출하여 합침
  • 위 작업들의 한계점
  1. human-labed training samples에 의존
  2. pre-trained CV or NLP 모델을 leverage 하기만 했으며 joint training을 시도하지 않음

그래서 LayoutLM을 제안

Figure 2: An example ofLayoutLM, where 2-D layout and image embeddings are integrated into the original BERT architecture. The LayoutLM embeddings and image embeddings from Faster R-CNN work together for downstream tasks.

  1. 문서 내에서 token의 상대적 위치를 나타내는 2-D position embedding. : token 간 관계를 포착
  2. 문서 내의 스캔된 token images에 대한 image embedding. : appearance feature(font directions, types, and colors 등) 포착
  3. MVLM(Masked Visual-Language Model) 손실 및 MDC(Multi-label Document Classification) 손실을 포함하여 LayoutLM에 대한 multi-task learning objective를 채택하여 text 및 layout에 대한 joint pre-training을 수행

LayoutLM은 컴퓨터로 작성된 문서보다는 더욱 어려운 "IIT-CDIP Test Collection 1.0"(11M scanned document images, 6M scanned documents)을 이용하여 pre-train을 수행.

Downstream benchmark로는 세 가지를 선택

  1. FUNSD: spatial layout 분석과 form understaining
  2. SROIE: Scanned Receipts Information Extraction
  3. RVL-CDIP:document image 분류 (400,000 grayscale images in 16 classes)

Contrbutions

  1. 최초로 document image로부터 textual과 layout 정보를 single framework로 pre-train하였으며 image feature로 이를 leverage.
  2. masked visual-language model (MVLM) 과 multi-label document classification(MDC)를 학습 objectives로 사용

2. LayoutLM

간략한 BERT의 review와 어떻게 text와 layout information을 joint train했는지 소개

2.1 The BERT Model

BERT는 attention-based bidirectional language modeling approach이다. BERT는 대규모의 training set을 self-supervised task 에 이용하여 효과적으로 knowledge transfer를 성공하였다. 기본적인 아키텍처는 multi-layer bidirectional Transformer encoder이다. 모델은 token의 sequence를 입력으로 받아 최종 representation을 출력한다. 입력 token은 Wordpiece로 생성되며 입력 embedding은 word embedding과 position embedding, segment embedding을 더하여 생성한다.

BERT는 pre-training과 fine-tuning이라는 두 단계가 있다. pre-training 단계에서는 두 개의 objectives를 사용하여 language representation을 학습한다. (Masked Language Modeling (MLM) and Next Sentence Prediction (NSP)) fine-tuning 단계에서는 task-specific한 dataset을 이용하여 모든 parameter를 end-to-end way로 update한다.

2.2 The LayoutLM Model

BERT-like model들이 많은 NLP task에서 성공을 거두었지만 visually rich document는 text 이외에도 더 많은 정보를 pre-trained model에 encode할 수 있다. 기본적으로 visually 풍부한 documents에서 language representation을 크게 향상시키는 두 가지 유형의 feature가 있다.

  • Document Layout Information
    • 문서 내에서의 단어의 상대적인 위치는 많은 semantic representation을 포함
    • 예로 Passport ID: 라는 글자의 key 값은 일반적으로 해당 글자의 오른쪽 혹은 아래에 위치
  • Visual Information
    • Visual information는 image feature으로 표현될 수 있으며 document representation에 효과적으로 활용
    • Document-level의 visual feature는 전체 image가 document layout을 나타낼 수 있으며 이는 document image classification과 같은 작업에서 중요한 역할
    • Word-level에서는 font의 style(bold, underline, italic)이 labeling task에서 중요한 힌트를 줄 수 있다.

2.3 Model Architecture

BERT 아키텍처 기반에 2-D position embedding과 image embedding을 사용했다.

  • 2-D Position Embedding
    • 2-D position embedding은 문서 내의 상대적인 위치를 알리기 위함
    • The bounding box는 $(x_0, y_0, x_1, y_1)$로 정의할 수 있으며, 이는 box의 좌측 상단$(x_0, y_0)$과 우측 하단 $(x_1, y_1)$의 값으로 구성
    • 위의 4개의 값에 대해 각각 X, Y look up embedding table 2가지를 이용하여 총 4개의 embedding layer를 구성
  • Image Embedding
    • 문서의 image feature을 text와 정렬하기 위해 language representation에서 image feature를 나타내는 image embedding layer를 추가
    • 더 자세하게는 OCR 결과에서 각 단어의 bounding box를 사용하여 이미지를 여러 조각으로 분할하고 단어와 일대일로 대응
    • Token image embedding으로 Faster R-CNN 모델의 image piece으로 image region feature를 생성
    • [CLS] 토큰의 경우 Faster R-CNN 모델을 사용하여 [CLS] 토큰의 표현이 필요한 downstream tasks에 도움이 되도록 전체 scanned document image를 region of interest(ROI)로 사용하여 embedding을 생성

2.4 Pre-training LayoutLM

  • Task #1: Masked Visual-Language Model(MVLM)
    2-D position embeddings와 text embeddings를 잘 학습하기 위함으로, pre-training 동안 임의로 input tokens를 masking 하지만, 해당 2-D position embeddings를 유지한 다음 주어진 context에서 masking된 token을 예측하도록 모델을 훈련
  • Task #2: Multi-label Document Classification(MDC)
    Document-level representation은 document image understanding을 위해 피룡하다. IIT-CDIP Test Collection은 각각 document image에 대해 multiple tags가 되어있으며, Multi-label Document Classification loss를 pre-training phase에서 사용한다. 각각 image에 대해 대규모 dataset에는 없을 수 있는 label이 되어있어야 하므로 큰 모델 구축시에는 제외할 수도 있다

2.5 Fine-tuning LayoutLM

Pre-trainged LayoutLM은 세 가지 document image understanding tasks에 대해 fine-tuning되었다. 각각은 form understanding, receipt understanding, document image classification이다. Form과 receipt understanding tasks에서 LayoutLM은 dataset의 각 entity 유형과 {B, I, E, S, O} tag를 예측한다. Document classification task의 경우 [CLS] token의 representation을 사용하여 예측한다.

출처) http://ml.cau.ac.kr/activities/outputs/210803-BIOES,%20Conditional%20Random%20Field.pdf

3. Experiments

3.1 Pre-training Dataset

IIT-CDIP Test Collection 1.0 (https://data.nist.gov/od/id/mds2-2531) 6M documents with more thana 11M scanned document images. 또한 각각의 document는 OCR로 생성된 text는 물론 metadata가 XML 형태로 저장되어 있다. 비록 data가 약간의 오류와 inconsistency tagging을 포함하지만 이런 대규모의 scanned document image는 본 모델의 pre-training에 최적이다.

3.2 Fine-tuning Dataset

  • The FUNSD Dataset
    • Form understanding in noisy scanned documents
    • 199(149 train, 50 test) real, fully annotated, scanned forms with 9,707 semantic entities and 31,485 words
    • 각각의 entity는 label과 bounding box 그리고 entities 간의 link가 있다.
  • The SROIE Dataset
    • Receipt information extraction
    • 626 (train) and 347 (test)
    • List of textline with bounding boxes, {company, date, address, total}로 label이 되어 있다. (이외는 None)
  • The RVL-CDIP Dataset
    • 400,000(320,000 train, 40,000 test, 40,000 validation) grayscale images in 16 classes, with 25,000 images per class
    • resized되어 1,000 pxl을 넘지 않는다.
    • 16 classes는 다음과 같다. {letter, form, email, handwritten, advertisement, scientific report, scientific publication, specification, file folder, news article, budget, invoice, presentation, questionnaire, resume, memo}

FUNSD 예시

{
        "form": [
        {
            "id": 0,
            "text": "Registration No.",
            "box": [94,169,191,186],
            "linking": [
                [0,1]
            ],
            "label": "question",
            "words": [
                {
                    "text": "Registration",
                    "box": [94,169,168,186]
                },
                {
                    "text": "No.",
                    "box": [170,169,191,183]
                }
            ]
        },
        {
            "id": 1,
            "text": "533",
            "box": [209,169,236,182],
            "label": "answer",
            "words": [
                {
                    "box": [209,169,236,182
                    ],
                    "text": "533"
                }
            ],
            "linking": [
                [0,1]
            ]
        }
    ]
    }{
        "form": [
        {
            "id": 0,
            "text": "Registration No.",
            "box": [94,169,191,186],
            "linking": [
                [0,1]
            ],
            "label": "question",
            "words": [
                {
                    "text": "Registration",
                    "box": [94,169,168,186]
                },
                {
                    "text": "No.",
                    "box": [170,169,191,183]
                }
            ]
        },
        {
            "id": 1,
            "text": "533",
            "box": [209,169,236,182],
            "label": "answer",
            "words": [
                {
                    "box": [209,169,236,182
                    ],
                    "text": "533"
                }
            ],
            "linking": [
                [0,1]
            ]
        }
    ]
}

entity 단위로 id와 label이 매겨지고 bounding box를 정의한 후 다른 entity와의 관계를 linking으로 정의했다.

3.3 Document Pre-processing

각 문서의 layout information을 활용하기 위해서는 각 token의 위치를 얻어야 한다. 그러나 pre-training dataset(IIT-CDIP Test Collection)에는 해당 bounding box가 누락된 순수 텍스트만 포함되어 있기 때문에 IIT-CDIP Test Collection의 원래 전처리와 마찬가지로 document image에 OCR을 적용하여 dataset를 유사하게 처리하면서 layout information도 얻는다. 오픈 소스 OCR 엔진인 Tesseract 덕분에 2-D position뿐만 아니라 recognition도 쉽게 얻을 수 있다. OCR 결과는 hierarchical representatin을 사용하여 single document image의 OCR 결과를 명확하게 정의하는 표준 사양 형식인 hOCR 형식으로 저장됩니다.

3.4 Model Pre-training

  • Pre-trained BERT base model로 parameter initialize
  • BERT base와 같은 아키텍처 (12-layer Transformer with 768 hidden sizes, and 12 attention heads, which contains about 113M parameters)
  • 15%의 input token을 예측했으며 이 중 80%를 [Mask] token으로 10%를 random token으로 10%는 변화를 주지 않았다
  • 2-D position embedding은 $(x_0, y_0, x_1, y_1)$의 embedding representation으로 행해졌으며, document layout이 page size마다 달라질 수 있기 때문에 (0, 1000) 값을 가지도록 scaled 되었다.
  • Visual Genome dataset으로 pre-trained된 ResNet-101 모델을 backbone으로 가지는 Faster R-CNN을 사용하였다.
  • 8ea NVIDIA Tesla V100 32GB GPUs를 이용, batch size 80, Adam optimizer, lr 5e-5 linear decay lr 조건에서 BASE 11M documents one epoch 당 80hr, LARGE 170hr

3.5 Task-specific Fine-tuning

  • Form Understanding
    • Two sub-tasks: semantic labeling & semantic linking
      • Semantic labeling: NER
      • Semantic linking: Entity 간의 relation 예측
    • 본 논문에서는 semantic labeling task만 집중했으며, 이를 sequence labeling 문제로 보았다.
    • Final representation을 linear layer 후 softmax layer 에 넣어 출력값으로 label을 예측
    • The model is trained for 100 epochs with a batch size of 16 and a learning rate of 5e-5
  • Receipt Understanding
    • Scanned receipt images에 대한 semantic slots(예로 company address, date and total)을 예측

SRIOE 예시

  • Document Image Classification
    • Document image의 category를 예측
    • 기존과는 다르게 해당 예측에 multi-modal로 접근
    • Fine-tune the model for 30 epochs with a batch size of 40 and a learning rate of 2e-5.

3.6 Results

Table 1: Model accuracy (Precision, Recall, F1) on the FUNSD dataset

FUNSD benchmark를 BERT와 RoBERTa 두 NLP model을 baseline으로 하여 비교했다. 결과는 RoBERTa가 훨씬 좋았으며 4가지 조건으로 LayoutLM을 test했다. 또한 MDC loss를 pre-training step에 추가할 경우 FUNSD 성능이 향상되는 것을 확인하였다. 또한 text+layout+image embedding을 모두 사용하였을 때 제일 좋은 성능을 보여주었다.

Table 2: LayoutLMBASE (Text + Layout, MVLM) accuracy with different data and epochs on the FUNSD datase

dataset과 epoch를 split하여 실험 진행한 결과를 table 2에 수록하였다. 결과는 epoch 수가 증가할수록 성능이 향상되는 것을 확인할 수 있다. 또한 학습에 사용된 데이터의 수가 많을 수록 더 좋은 성능을 보여주었다. 즉 FUNSD의 training data가 149장 밖에 없는 것을 고려해볼 때 text와 layout의 pre-training이 low-resource scanned document understanding에 효과적이라는 것을 알 수 있다.

Table 3: Different initialization methods for BASE and LARGE (Text + Layout, MVLM)

LayoutLM이 RoBERTa로 초기화 되었을 경우 더 좋은 성능을 보여주는 것을 확인할 수 있다.

Table 4: Model accuracy (Precision, Recall, F1) on the SROIE dataset

receipt understaning은 SROIE dataset을 이용하여 평가하였다. SROIE의 key information task만 수행하였다. OCR 결과의 부정확성 없애기 위해 ground truth OCR 결과만 사용하여 학습하였다. 11M 데이터가 가장 좋은 성능을 보여준 것에서 FUNSD와 같이 pre-training이 해당 task에도 효과적이라는 결론을 낼 수 있다.

Table 5: Classification accuracy on the RVL-CDIP dataset

Document image classification은 RVL-CDIP dataset으로 이루어졌으며, 해당 작업에서 BERT와 RoBERTa가 다른 image-based model보다 성능이 떨어지는 것에서부터 text information만으로는 해당 작업을 해결하는데 부족하다는 것을 확인할 수 있다. LayoutLM의 경우 image feature 없이도 기타 single image-based approach 성능을 상회하는 것을 확인할 수 있다. 그리고 Image feature를 사용할 경우 SOTA baseline을 상회하는 것을 알 수 있다.

5. Conclusion and Future Work

본 논문에서 LayoutLM을 제안하며 이 모델은 single framework에서 text와 layout information을 한번에 pre-training한다. Transformer 아키텍처를 기반으로해서 multi-modal input(token embeddings, layout embeddings image embeddings)의 이점을 가져올 수 있다. 해당 모델은 대규묘 unlabeled scanned document images에 대해 self-supervised 발향성을 제시한다. LayoutLM은 form understanding, receipt understanding, document image classification downstream tasks에 대해 평가하였으며 몇몇 SOTA 모델들을 상회했다.

유첨. code

import logging

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertConfig, BertModel, BertPreTrainedModel
from transformers.modeling_bert import BertLayerNorm

logger = logging.getLogger(__name__)

LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_MAP = {}

LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class LayoutlmConfig(BertConfig):
    pretrained_config_archive_map = LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP
    model_type = "bert"

    def __init__(self, max_2d_position_embeddings=1024, **kwargs):
        super().__init__(**kwargs)
        self.max_2d_position_embeddings = max_2d_position_embeddings #기본 1024로 embedding


class LayoutlmEmbeddings(nn.Module):
    def __init__(self, config):   
        super(LayoutlmEmbeddings, self).__init__()
        '''
        torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, ...)

        num_embeddings (int) – size of the dictionary of embeddings
        embedding_dim (int) – the size of each embedding vector
        padding_idx (int, optional) – If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”. For a newly constructed Embedding, the embedding vector at padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector.
        '''
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.x_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.y_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.h_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.w_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids,
        bbox,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])  #x0
        upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) #y0
        right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) #x1
        lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) #y1
        h_position_embeddings = self.h_position_embeddings(
            bbox[:, :, 3] - bbox[:, :, 1]
        ) #height
        w_position_embeddings = self.w_position_embeddings(
            bbox[:, :, 2] - bbox[:, :, 0]
        ) #width
        token_type_embeddings = self.token_type_embeddings(token_type_ids) # 아마 SROIE의 key값{company, date, address, total} 값들을 추가하는 것으로 보여짐

        embeddings = (
            words_embeddings
            + position_embeddings
            + left_position_embeddings
            + upper_position_embeddings
            + right_position_embeddings
            + lower_position_embeddings
            + h_position_embeddings
            + w_position_embeddings
            + token_type_embeddings
        ) # 그림과 다르게 w, h 도 추가되고 token type도 들어감
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class LayoutlmModel(BertModel):

    config_class = LayoutlmConfig
    pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def __init__(self, config):
        super(LayoutlmModel, self).__init__(config)
        self.embeddings = LayoutlmEmbeddings(config)
        self.init_weights()

    def forward(
        self,
        input_ids,
        bbox,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = (
                    head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                )
                head_mask = head_mask.expand(
                    self.config.num_hidden_layers, -1, -1, -1, -1
                )
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to fload if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(
            input_ids, bbox, position_ids=position_ids, token_type_ids=token_type_ids
        )
        encoder_outputs = self.encoder(
            embedding_output, extended_attention_mask, head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class LayoutlmForTokenClassification(BertPreTrainedModel):
    config_class = LayoutlmConfig
    pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = LayoutlmModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids,
        bbox,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):

        outputs = self.bert(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[
            2:
        ]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


class LayoutlmForSequenceClassification(BertPreTrainedModel):
    config_class = LayoutlmConfig
    pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def __init__(self, config):
        super(LayoutlmForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = LayoutlmModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob) #기본 0.5
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids,
        bbox,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):

        outputs = self.bert(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[
            2:
        ]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)