이 논문은 새로운 방식의  bbox regression 방식인 DFL(distribution focal loss) 와 localization quality와 classification score 동시에  표현해 최적화 하는 quality focal loss를 제안하고 이 둘을 합쳐 Generalized Focal loss라고 명명한다.

localization quality란 FCOS 관점에서 보면 object의 centerness score에 해당한다.


Tile: Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection

 

Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection

One-stage detector basically formulates object detection as dense classification and localization. The classification is usually optimized by Focal Loss and the box location is commonly learned under Dirac delta distribution. A recent trend for one-stage d

arxiv.org

git:https://github.com/implus/GFocal/blob/master/mmdet/models/losses/gfocal_loss.py

 

기존 방식의 문제점:

문제 1. 학습/추론 시 localization quality estimation 과 classification score 의 사용 방식이 불일치


사유 1. FCOS 같은 one stage detector들은 학습시 classification score 와 centerness(또는 iou)score 가 별개로 학습되지만 inference 시에는 nms전에 두 score를 join해서 사용(element wise multiplication)한다. 위 Figure 1의 (a)의 train와 test이 이를 잘 보여준다.
사유 2. positive sample 위치에만 localization quality estimation에 대한 label이 주어진다. 학습과 추론 단계에서의 두 score의 학습/사용 방식이 상이한 점이 성능 저하로 이어 질 수 있음

문제 2. 박스 표현의 경직성(Inflexible representation of bouding boxes)
기존 방식들은 positive sample 위치에만 box gt를 할당해 regression 하는 방식을 취하는데 이는 dirac delta distribution으로 볼 수 있다. 이유는 이런 단순 한 box gt 할당은 database에 존재 할 수 있는 다양한 애매하고 불명확한 상황을 고려하지 못하기 때문이다. 예를 들어 물체의 occlusion, shadow, blur등으로 인해 물체의 경계가 불분명 해 질 수 있고 이 경우 Dirac delta distribution은 이런 경우를 커버하기엔 제한 적이다.

위 Figure 3에서 보면 이 논문에서 제안하는 방식으로 bbox 표현을 학습 하면 가림, 그림자, 흐림 등의 경우에도 target 물체의 모양을 고려해 더 fit한 bbox를 추측 할 수 있다는 것을 나타낸다. 왼쪽 그림의 경우 서핑 보드가 파도에 의해 가림이 생겼는데 이 논문에서 제안한 방식으로 학습한 경우 녹색과 같이 박스가 추측 된다.
이미지 오른 쪽의 그래프는 아직 이해 하지 못해도 된다. 이것 distribution focal loss파트를 읽어 보면 이해 되는 그래프이다.

몇몇 논문에서 bbox를 gaussian distribution으로 표현해 학습 하는 방법을 제안했지만 이는 단순 한형태여서 다양한 상황을 커버하지 못한다(고 주장한다)


Method

Quality Focal Loss(QFL)


quality focal loss는 localization quality estimation과 classification score를 혼합한 classification-iou score를 최적화 하기 위한 loss function으로 위 “기존 방식의 문제점” 섹션에서 언급한 문제 1의 train-test inconsistency를 해결한다.

localization quality estimation과 classification score를 혼합했기 때문에 one-hot category label이 아닌 soften된 label $y \in [0,1]$ 이 사용된다. $y=0$은 negative sample로 0 quality score(IoU score) 를 나타내고, $ 0 < y \leq 1 $ 은 positive sample로 quality score y가 loss의 target y로 사용된다.

여기서 localization quality estimation과 classification score를 혼합하면 왜 one-hot category label 이 아니라 soften된 label 이 target이 되는지 의문이 들 수 있다. 나도 처음엔 이게 의문이었다.  이유는 classification score(one-hot)에 각 positive sample의 위치 anchor(또는 center position)에 해당하는 pixel에서 추론된 predicted bbox와 target bbox의 IoU score를 곱해서 target으로 사용 하기 때문이다. (one-hot label에서 positive sample의 label은 1이니까 IoU score곱한다는 의미는 IoUscore를 target label로 사용하겠다는 것과 같다). 이게 위 Figure 4의 existing work 과 GFL의 차이에서 label이 soften된다고 나타낸 이유이다.

QFL의 수식은 아래와 같다.
$$QFL(\sigma) =  -\left \vert y - \sigma \right \vert ^{\beta} ((1-y)\log{(1-\sigma)} + y\log{\sigma}$$

focal loss에서 달라진 부분은 두 부분이다.
1. cross entropy part인 $-\log{(p_t)}$ 가 binary classification의 complete form인 $1((1-y)\log{(1-\sigma)}+y\log{\sigma}$ 로 바뀌었다.
2. scaling factor $(1-p_t)^gamma$ 가 추정치 $\sigma$와 label $y$의 L1 distance $\left \vert y - \sigma \right \vert ^ \beta$ 로 바뀌었다.

위 형태에서 $y=\sigma$일 때 global minimum을 갖는다.

multi-class classification의 경우 sigmoid를 이용해 multiple binary classification으로 문제를 정의 한다. multiple binary classification에서 각각의 binary classification을 위와 같은 방식으로 풀면되므로 어려울 건 없다.

코드는 아래와 같고 mmdetection 에서 가져왔다. 필요 한 부분에 주석을 달아 두었으니 위 내용과 비교하며 보자.

def quality_focal_loss(
          pred,          # (n, 80)
          label,         # (n) 0, 1-80: 0 is neg, 1-80 is positive
          score,         # (n) reg target 0-1, only positive is good
          weight=None,
          beta=2.0,
          reduction='mean',
          avg_factor=None):
    # all goes to 0
    pred_sigmoid = pred.sigmoid()
    pt = pred_sigmoid
    zerolabel = pt.new_zeros(pred.shape)
    # 아래는 negative sample 에 대한 loss
    loss = F.binary_cross_entropy_with_logits(
           pred, zerolabel, reduction='none') * pt.pow(beta) 
	
    label = label - 1
    pos = (label >= 0).nonzero().squeeze(1)
    a = pos
    b = label[pos].long()
    
    # positive goes to bbox quality
    pt = score[a] - pred_sigmoid[a, b]
    
    # positive sample 에 대한 loss 이고 target 으로 사용 되는 score 는 이 논문에서 제안한
	# iou 와 classification score 를 결합하여 soften 한 label 이다. 
    loss[a,b] = F.binary_cross_entropy_with_logits(
           pred[a,b], score[a], reduction='none') * pt.pow(beta)

    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

Distribution Focal Loss


 DFL은 참신 하면서도 이해가기 매우 쉽다.  DFL에 대하 설명 하기 전에 이 논문에서 bbox regression문제를 어떻게 정의 했는지 먼저 살펴보자. 이 부분이 꽤 참신하며 yolo등에서 bbox regression에 이 방식을 채용했다.

 이 논문은 anchor 또는 center point위치 로 부터 bbox의 각 변까지의 거리를 regression으로 직접 추론 하는 방식 대신 기대값을 구하기 위한 distribution을 추론 하는 방식으로 문제를 변경했다. 기존 bbox 추론 네트워크 들은 대부분 object 중심에 해당하는 anchor 에서 bbox의 각 변까지의 거리 (l,t,r,b)(아래 그림 참조) 에 해당 하는 4가지 값 scalar를 직접 추론 하는 방식을 채택했다. 이 방식을 굳이 수식으로 나타내면 특정 값에서만 확률 이 1인 Dirac delta function으로 나타낼 수 있기 때문에 이 논문에서는 기존 방식들은 distribution을 dirac delta로 가정 하고 문제를 풀었다고 말한다.

object 중심으로 부터 각 변까지의 거리 l,t,r,b

이 논문은 l,t,r,b를 직접적으로 추존 하는 대신 l,t,r,b의 확률 분포를 추론하고 이를 이용해 기대값을 계산함으로써 최종 l,t,r,b를 계산 한다. 수식으로 보면 아래와 같다.
$$\hat{y} = \sigma_{i=0}^{n}P(y_{i})y_{i} $$
위 식에서 $y_{i}$는 각변 까지의 거리 l,t,r,b의 discrete 한 값이고 $P(y_{i})$는 네트워크가 추론한 현 anchor에서 object boundary 까지의 거리 l,t,r,b가 $y_{i}$일 확률 값이다.

좀더 구체 적으로 예를 들자면 DFL은 object boundary 까지의 거리를 직접적으로 추론하는 것이 아니라 anchor 로 부터 object의 왼쪽 경계 까지 거리 $l$이 1일 확률 0.01, 2일 확률 0.05, 3일 확률 0.06, … 8일 확률 0.5, 9일 확률 0.2, … 16일 확률 0.01 이니까 $l$의 기대값은 XX이다! 라고 추론하고 이 기대값이 최종 추론 값이다.


여기서 당연한 의문이 생길 수 있는데 “확률 분포는 그렇다 치고 $P(y)$ 를 구해서 기대값은 구하려면 $y$를 당연히 알아야 하는데 이건 어떻게 구했나?” 가 그 의문이다.
논문의 저자들은 coco trainval135k 데이터 셋에서 bbox regression target $l,t,r,b$의 histogram을 구했다. 그 결과가 아래 histogram 이다. 이를 보면 x축인 regression target이 약 1~16까지 분포해 있는 것을 알 수 있다.
이 정보에 기반해 $ y \in [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]$ 으로 미리 정해 놓고 각 $y_{i}$가 현 anchor에 대응 하는 각 변까지의 거리일 확률 분포 $P(y_{i})$를 추론 하도록 했다.


이러한 확률 분포 $P(y)$를 fitting 하기 위해 이 논문은 DFL(distribution focal loss)를 제안 하는데 아주 간단하고 명확해 보인다. 식은 아래와 같다.
$$DFL(S_{i}, S_{i+1}) = -((y_{i+1}-y)\log(S_{i}) + (y-y_{i})\log(S_{i+1}))$$
target 을 y라 할때 y와 가장 가까운 값 $y_{i} \leq y \ge y_{i+1}$ 인 $y_{i}, y_{i+1}$에서 $P(y)$가 peak 값을 가지도록 위와 같은 complete form의 cross entropy 를 이용해 학습 한다. $S_{i}=\frac{y_{i+1}-y}{y_{i+1}-y_{i}}, S_{i+1} = \frac{y-y_{i}}{y_{i+1}-y_{i}} $ 이다. ($S_{i},S_{i+1}$ 이 왜 저렇게 되는지 궁금하면 linear interpolation을 생각해 보자)

예를 들어 실례를 들어 보자면 unnormalized logits
$ y_{i}=[9.38, 9.29, 4.48, 2.55, 1.30, 0.42, 0.03, -0.28, -0.51, 0.83, -1.27, -1.56, -1.78, -1.94, -.193, -1.38]$ 이라고 할때 target $y=0$ 이면 cross entropy 에 의해 $p(y_{0}), p(y_{1})$이 근처에서 peak 값을 갖도록 학습이 되는 방식이다.

코드로는 아래와 같다.  코드는 공식 mmdetection 에 삽입된 코드를 가져왔다. 

def distribution_focal_loss(
            pred, #normalized 되지 않은 y prediction 값이다. 즉 softmax 씌우기 전의 값
            label, # target label 이다.
            weight=None,
            reduction='mean',
            avg_factor=None):
    disl = label.long()
    disr = disl + 1

    wl = disr.float() - label  # y_{i} 에 대한 weight 계산
    wr = label - disl.float() # y_{i+1} 에 대한 weight 계산

    loss = F.cross_entropy(pred, disl, reduction='none') * wl \
         + F.cross_entropy(pred, disr, reduction='none') * wr
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

 

 

결과


이 논문에서 제안한 QFL, DFL의 효과에 대한 결과 이다. DFL보다는 QFL이 더 직접적으로 AP 향상에 영향을 미치는 것 처럼 보인다. QFL 만 적용 했을때가 DFL만 적용 했을때 보다 AP 컷다. 둘 모두 사용 하면 둘중 하나만 썻을 때 보다 AP 것을 알 수 있다.


fcos와 atss에 QFL을 적용한 결과 전반적인 성능향상을 보이는 것을 알 수 있다

DFL도 적용시 전반적으로 성능향상을 이끌어 낸다늩 것을 알 수 있다

 

참조: mmdetection 및 논문 원문


-끝-

'Deeplearning > Loss' 카테고리의 다른 글

[metric loss] additive angular margin loss  (0) 2021.11.26
[Metric learning loss] Triplet loss 설명  (0) 2021.11.24

 

 

문제 정의:

목표는 calibrated 된 카메라를 이용해 촬영된 인물의 head pose와 추정된 head pose의 uncertainty를 world coordinate frame에서 모델링 하는 것을 목표로 한다. 

+ 추가로 head pose(=face pose)를 이용해 인물의 시선을 예측하고 예측된 gaze point의 불확실성을 구하는 것까지 해본다.

 

노테이션 정의:

Pose $X^{frame}_{object}=(\psi, \phi, \theta, x, y, z) \in R^6$는  기준 coordinate frame 상에서 object의 pose 를 나타낸다. 

$$\psi : z axis 를 회전 축으로 하는 회전 각도, 단위:degree. $$

$$\phi : y axis 를 회전 축으로 하는 회전 각도, 단위: degree. $$

$$\theta : x axis 를 회전 축으로 하는 회전 각도, 단위: degree. $$

$$ x : object 의 x 좌표 $$

$$ y: object 의 y 좌표 $$

$$ z: object 의 z 좌표 $$

 

단, face coordinate frame 에서 face의 pose는 $X^{face}_{face}=(0,0,0,0,0,0)$ 으로 정의한다.

이 문제에서는 $X^{face}_{face}, X^{face}_{camera}, X^{camera}_{world} $ 3개의 pose 가 등장한다.

 

Pose 와 transformation matrix 관계

$$^{object}_{frame}H = f(X^{frame}_{object}) = \begin{bmatrix}^{object}_{frame}R&-^{object}_{frame}RT_{object}\\0&1\\ \end{bmatrix}$$

$$^{object}_{frame}R=Rz(\psi)Ry(\phi)Rx(\theta),  T_{object}=\begin{bmatrix} x\\y\\z \end{bmatrix}$$

$$p^{object} = ^{object}_{frame}H \times p^{frame}, p^{frame}=\begin{bmatrix} x\\y\\z\\1 \end{bmatrix}$$

$$^{world}_{face}H = ^{world}_{camera}H \times ^{camera}_{face}H$$

$$X^{frame}_{object}=f^{-1}(^{object}_{frame}H)$$

 

좌표계 정의

 

Pose 추정

$X^{face}_{camera} = SolvePnP(q^{face}_{i}, p^{image}_{landmark-i},K, dist)$

$X^{camera}_{world} = SolvePnP(q^{world}_{i}, p^{image}_{target-i}, K, dist)$

 

K: projection matrix, dist = distortion parameter

 

$p^{image}_{landmark-i}$ : image 좌표계 상의 i번째 face landmark 점 좌표 (x,y). 

$q^{face}_{i}$ : face coordinate frame 상의 $ p^{image}_{landmark-i}$ 대응점 (x,y,z).

$p^{image}_{target-i}$ : image 좌표계 상의 i번째 calibration target landmark 점 좌표 (x,y). 

$q^{world}_{i}$ : world coordinate 상의  $p^{image}_{target-i}$ 대응점 (x,y,z).

 

추정 오차 공분산 계산:

$p^{image}_{i} = f(X^{face}_{camera}) \times q^{face}_{i}$ 이 수식을 테일러 시리즈를 이용해 선형 근사 하면

$ \overline{p}^{image}_{i}+\Delta p = f( \overline{X}^{face}_{camera}+\Delta X) \times q^{face}_{i} $

 

$\Delta p \approx \frac{\partial f}{\partial X} \Delta X^{face}_{camera} = M_{i} \Delta X^{face}_{camera}$

이 식을 least square 로 풀면 Pose $X^{face}_{camera}$ 의 추정 오차 공분산은 

$C_{x} = E[\Delta X \Delta X^{T}] = (M^{T}M)^{-1}M^{T}E(\Delta p \Delta p^{T})((M^{T}M)^{-1}M^{T})^{-1}$

$E(\Delta p \Delta p^{T}) : ^{camera}_{face}H 를 이용해 projection 한 point의 공분산 행렬$ 로 나타낼 수 있다.

$C_{x}$는 6x6 사이즈의 행렬이다.

 

World-camera 의 추정 오차 공분산 행렬도 $C_{w} = 6 \times 6$  이고  위와 동일 한 방식으로 calibration target을 이용해 구할수 있다.

 

최종 $^{world}_{face}H$의 추정 오차 공분산을 $C_{y}$ 라 하면

$X^{face}_{world}=f^{-1}(^{world}_{face}H)$

$^{world}_{face}H= ^{world}_{camera}H \times ^{camera}_{face}H = f(X^{camera}_{world}) \times f(X^{face}_{camera})$ 관계에 의해 공분산 전파(propagation) 식에 의해 아래와 같이 구할 수 있다. 

 

$C_{y} = J_{X} C_{x} J^{T}_{X} + J_{W} C_{W} J^{T}_{w}, C_{y} = 6 \times 6 행렬$

 

$J_{x} = \frac{\partial f^{-1}(^{world}_{face}H)}{\partial X^{face}_{camera}} $

 

$J_{w} = \frac{\partial f^{-1}(^{world}_{face}H)}{\partial X^{camera}_{world}} $

 

이를 이용해 face pose 는 world에서 $X^{face}_{world}=f^{-1}(^{world}_{face}H)$ 를 평균으로 하고 $C_{y}$ 분산으로 하는 gaussian pdf 를 따른다고 볼 수 있다. 

 

다음으로 head pose로 부터 gaze point를 예측 해보자. 

여기서는 시선의 방향이 얼굴의 전면부 즉 코끝이 가리키는 방향과 같다고 가정한다. 

 

gaze point $g$를 아래와 같이 정의하자. 

$$g=\begin{bmatrix} x_{g}\\y_{g}\\z_{g} \end{bmatrix} = p_{face} + tV_{gaze} = g(^{world}_{face}H)      --------(func gaze)$$

$V_{gaze} =\begin{bmatrix} g_{x}\\g_{y}\\g_{z} \end{bmatrix}: gaze direction vector, ^{world}_{face}R의 마지막 컬럼, 즉 ^{world}_{face}H[:,2] = f(X^{face}_{world})[:,2]$

$p_{face} = \begin{bmatrix} x_{face}\\y_{face}\\z_{face} \end{bmatrix}$, 얼굴 위의 한점 여기선 face coordinate frame origin의  world coordinate frame 상의 좌표로 설정, 즉 $^{world}_{face}H[:,3]=f(X^{face}_{world})[:,3]$

 

이때 $V_{gaze}$는 시선의 방향을 나타네는 벡터 즉 시선의 방향 벡터라고 볼 수 있다. 가정에 의해 시선의 방향은 얼굴 평면과 수직(perpendicular) 이므로 world 좌표계에서 head pose를 나타내는 rotation matrix의 3번째 컬럼 즉 face coordinate frame의 z axis에 해당한다. 3차원 공간상에서 방향벡터 $V$와 평행하고 사람의 얼굴 위의 한점(눈 사이의 한점을 잡는게 가장 좋으나 여기서는 코끝으로 가정했다.)을 지나는 직선을 구하는게 목적이므로 $p_{face}$는 코끝의 world coordinate frame 상의 좌표로 가정하자. 

 

이렇게 하면 world coordinate frame 상의 z-x평면위에서 이미지 상의 특정 인물이 바라보고 있는 좌표(gaze point)와 불확실성은 아래와 같이 구할 수 있다.

$t= -\frac{p_{face}}{g_{y}} $ 로 설정 하면 y=0이 되므로 world 좌표계 상에서 x,z평면과 만나는 gaze point를 구할 수 있고

gaze point 와 pose $X^{face}_{world}$의 관계에 따라 gaze point 의 공분산은 

$C_{g} = J_{g} C_{y} J^{T}_{g} , 3 \times 3 행렬$

$J_{g} = \frac {\partial g(^{world}_{face}H)}{\partial g}, 3 \times 6 행렬$

로 구할 수 있다. 

 

 

다음은 face pose(=head pose) estimation에 사용된 landmark를 표시한 그림이다. 

 

'Deeplearning > toyproject' 카테고리의 다른 글

[Deskew for ocr] Rotation correction v2  (0) 2022.03.06
[Deskew for ocr] Rotation correction  (0) 2022.02.15

이번에 정리할 논문은 MobileViT로 transformer와 convolution을 조합해 더 좋은 feature를 학습할 수 있음을 보인 논문이다.

제목: MobileViT
link: https://arxiv.org/pdf/2110.02178.pdf


인트로 덕션 요약:

self-attention-based 모델은 컨볼루션 네트웍의 대안이다.
ViT 계열의 트랜드는 모델 파라미터를 늘려서 성능을 끌어 올리는 방식인데 이는 모델 사이즈를 키우고 latency를 증가시킨다. 따라서 edge device 같은 자원이 제한적인 환경에서 이는 문제가 된다.

mobile device의 자원 제약을 충족할 정도로 Vit 모델 사이즈가 줄어 들 수 있지만 DeIT(Touvron et al., 2021a) 의 경우에서 보이듯 light-weight CNN(MobileNetV3)보다 3% 낮은 성능을 보일 정도로 성능하락이 심하다.

Vit 계열의 문제점
1. 모델이 무겁다. ex) ViT-B/16 vs mobileNetV3: 86 vs 7.5 million parameters
2. 최적화 하기 어렵다(Xiao et al., 2021 )
3. 오버 피팅을 방지 하기 위해 광범위한 Data augmentation 과 L2 regularization 이 필요하다. (Touvron et al., 2021 a; wang et al., 2021)
4. dense prediction task ex) segmentation,과 같은 down-stream task 를 위해 무거운 decoder가 필요하다. (ex. Vit-base 세그멘테이션 네트웍(Ranftl et a. 2021.,) 은 345 million parameters를 사용해 CNN계역의 DeepLabv3(59 million parameters)와 비슷한 성능을 얻었다.)

ViT 가 비교적 많은 파라미터를 필요로 하는 것은 CNN 계열이 선천적으로 지니고 있는 image-specific inductive bias가 부족하기 때문으로 보인다. (Xiao et al., 2021)

강인한 고성능 Vit 모델 개발을 위해 convolution과 transformers의 결합이 주목받고 있다. (xiao et al., 2021, d'Ascoli et al., 2021; Chen et al., 2021b). 그러나 이런 모델은 여전히 무겁고 data augmentation 에 예민하다.
증거로, Cutmix와 DeIT-style data augmentation 을 제거하면 Heo et al(2021)은 imageNet accuracy가 78.1% 에서 72.4% 로 감소한다.

이러한 이유로 CNN과 transformer의 강점을 조합해 mobile vision task를 위한 ViT 모델을 만드는 것은 여전히 숙제로 남아있다.(본인들 연구 정당성을 이렇게 정성들여 조목조목 주장할 수 있는 건 정말 큰 능력인것 같다.)

이 논문에서 저자는 light-weight, general-purpose, low-latency이 3가지 관점에 초점을 맞춰 mobilevit라는 모델을 제안한다.

여기서 low-latency와 관련해 FLOPs라는 지표가 적절하지 않다고 언급하는데 이유는 다음과 같다.

FLOPs는 아래와 같은 latency에 영향을 미치는 요소들을 고려하고 있지 않다.
1. 메모리 엑세스
2. 병렬성 정도
3. 플랫폼 특성
그 증거로 PiT는 DeIT 에 비해 1/3배의 FLOPs 를 가지고 있으나 iPhone-12에서 비슷 한 latency를 보여준다.
(10.99ms vs 10.56ms)
따라서 이 논문은 FLOPs 측면에서의 최적화를 목표로 하지 않는다. (즉, FLOPs는 상대적으로 큰 편이다...)

정리: MobileVit는 CNNs의 장점인 spatial inductive biases 와 less sensitivity to data augmentation 과
ViT의 장점인 input-adaptive weighting and global processing을 잘 조화했다.(가 주장하는 바이다.)

특히 MobileVit block를 제안하는데 이 블럭은 local 과 global information 모두를 tensor 효율적으로 encode 한다. (한글로 풀어 쓰고 싶은데 표현하기 정말 난해 하다...)
Convolution을 사용/사용 하지 않는 ViT 계열과는 다르게, mobileVit는 global representation을 학습하는데 다른 관점을 지니고 있다.
standard 컨볼루션은 3개 연산으로 이루어 진다. unfolding, local processing, folding.
MobileVit Block은 local processing을 transfomer를 이용한 global processing으로 치환해서 CNN과 ViT의 특성을 가진다. 그로인해 적은 수의 파라미터와 간단한 학습 레시피(data augmentation이 ViT에 비해 간단함을 의미) 더 좋은 feature를 배운다.

결과적으로 MobileViT는 5~6million의 파라미터 수로 mobileNetV3보다 3.2% 더 높은 Top-1 accuracy를 ImageNet-1k dataset에서 보여 준다.

기존 모델들의 한계:

light-weight cnn:

공간적으로 local information에 의존한다. Convolution 필터 자체가 인접 픽셀들간의 관계로 부터 representation을 학습하니 당연하다.
다만 layer가 깊어 질수록 receptive field가 커짐에도 불구하고 이렇게 표현한 걸 보면 상대적으로 local 하다는 표현인것 같다.


Vision transformer:

많은 파라미터를 사용한다.
데이터가 크지 않으면 오버피팅 문제가 있다.
extensive data augmentation 이 필요하다.
convolution을 이용하는 ViT 모델들도 있지만 여전히 heavy weight이고 light-weight CNN 모델들 수준으로 파라미터 수를 줄이면 CNN 계열 보다 성능이 떨어지는 문제를 한계로 지적했다.

ViT 계열은 image-specific spacial inductive bias가 부족 하기 때문에 더 많은 파라미터와 data augmentation이 필요하다고 언급하는데 convolution을 ViT에서 이용 함으로서 이 부족한 inductive bias를 ViT 모델에 이식할 수 있다고 한다.

논문에서 언급된 여러 모델들이 convolution을 ViT 모델에 서로 다른 방식으로 이용해 CNN과 ViT의 장점을 이용해 강건하고 고성능의 ViT 만들었다고 한다. 하지만 여전히 남은 문제는 "convolution과 ViT를 어떻게 조합해야 둘의 강점을 잘 이용 할수 있는가?" 라는 점이다.

위와 같은 문제를 가지고 있는 기존 방식에 비해 MobileVit의 강점은 :

1. Better performance(주어진 파라미터 한계에서 light-weight CNN 보다 높은 성능을 보임)
2. Generalization capability: 다름 ViT variant 모델들 모다 generalization capability가 좋음
3. Robust: Data augmentation과 L2 regularization에 덜 민감함

MobileViT block:

이 논문의 핵심은 MobileViTBLock을 통해 적은 수의 파라미터로 local 과 global information을 modeling 하는 것이다.

Fig 1. MobileViT Diagram

Fig 1 의 녹색 영역은 MobileViTBlock의 연산을 도식화한 것이다.
MobileVitBlock에서 local spacial information ( Fig 1 녹색 영역의 Local representations 로 표시된 부분)
$n \times n$ convolution을 이용해 획득하고 point-wise convolution을 이용해 인풋 텐서 채널의 선형 조합으로 고차원 정보를 생성한다. --------------------------------------------------------------------------------------(1)

global information은 multihead attention을 이용해 학습하게 된다.

input을 $X_L$이라 할때 $X_L$을 서로 오버랩 되지 않는 patch $X_U = R^{P \times N \times d}$ 로 나타낸다.
$P = w \times h$ 로 h,w는 패치의 가로,세로 크기 이고, $N=\frac{HW}{P}$은 input tensor내에 존재하는 패치의 수이다.

transformer가 local spatial information과 inter_patch 정보를 모두 모델링 하기 위해서는 h<=n, w<=n 을 반드시 만족해야 한다. 즉 non-overlap 패치의 크기는 convolution kernel의 크기보다 반드시 같거나 작아야 한다.

아래 그림 Fig 2는 MobileViTBlock 이 왜 local 과 global representation을 모두 학습 할 수 있는 지 보여준다.
검은 색 굵은 선으로 분리된 각 영역은 오버랩되지 않은 patch이고, 각 patch 내 회색 선으로 분리된 공간은 patch에 속하는 pixel 이라 보면 된다. 왼쪽 하단의 patch를 보면 파란색 화살표가 파란색으로 표시된 pixel 주변의 정보를 취학하는 것을 표현 했는데 이는 $ n \times n$ convolution을 통해 이루어 진다. 아래 그림의 정가운데 위치한 patch 내에 붉은 색으로 표시된 pixel은 다른 patch의 자신과 동일 한 위치의 pixel들의 정보를 취합 하는데 이과정이 multi-head attention을 통해 이루어 진다.

Fig 2. How local and global representation are learned by MobileVitBlock

이를 좀더 수식 적으로 표현 하면 $p = {1,...,P}$ 라 할때 패치간의 관계는 모델은 다음과 같이 표현 할 수 있다.
$$ X_{G}(p)=Transformer(X_{U}(p)), 1<=p<=P$$
ViT와 다르게, 위와 같이 모델링 되는 MobileVit는 패치간의 순서나 pixel의 공간 정보도 잃지 않기 때문에 MobileVitBlock의 입력 $X_{L}$의 각 픽셀 위치에 해당하는 local, global information이 축약된 정보를 입력 모양에 맞게 $X_{F}=R^{H \times W \times d}$로 복원가능 하다.
그 후 $X_{F}$는 point-wise-convolution을 이용해 저차원으로 projection 되고 mobileVitBlock의 입력 X와 concat된 후 $n \times n$ convolution을 이용해 feature를 혼합한다.

$X_{U}$은 $n \times n$ convolution으로 local information을 학습하고 $X_{G}$는 패치간의 정보를 취합하기 떄문에 $X_{G}$의 각 픽셀은 결국 입력 X의 모든 정보를 취합한다고 볼수 있다.
(코드 상으로 보면 X_G의 임의의 패치 p에 속한 임의의 pixel i(hi,wi)는 X_G의 오버랩 되지 않는 p가 아닌 다른 패치들의 q에 속한 pixel j(hi,wi)위치와 attention이 계산되므로 엄밀히 말하면 모든 픽셀에 대한 정보를 취합하는 것이 아니라 다른 패치 내의 동일한 위치에 있는 pixel의 정보만 취합한다. 따라서 입력 X의 모든 픽셀에 대한 정보를 X_G가 취합 하게 하려면 nxn convolution의 커널 사이즈 n와 , 패치 사이즈 h,w를 신중하게 디자인 해야 한다. )

코드로 보면 더 직관 적이니 아래 코드를 보자.
각 라인 옆에 위 설명의 $X_{L}$, $X_{U}$, $X_{G}$에 해당 하는 값을 표기해 두었다.

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
    
    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x) # 여기서 x = X_L 이다.
        
        # Global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x) # 파라미터 x = X_U 이고 output x = X_G이다. 
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
		# 바로 윗 라인의 reshaped 된 x = X_F에 해당 된다. 
        
        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return


relationship to convoltuions:

컨볼루션은 아래와 같은 3개 연산의 스택으로 볼 수 있다. Fig 1 의 녹색 영역을 참고해서 보면 좀더 쉽게 이해 할 수 있다.
1. unfold
2. matrix multiplication
3. fold
본 논문에서 제시한 mobileVitBlock는 아래와 같은 3단계 연산을 한다는 측면에서 transformer를 convolution 처럼 본다고 볼수 있다.
1. 입력 $X_{L} \in \in R^{H \times W \times d}$을 non-overlaping patch $X_{U} \in R^{P \times N \times d}$로 unfolding 하고
-> convolution의 unfolding에해당
2. transformer layer를 통해 global representation을 배우고
-> convolution의 matrixmultiplication에 해당
3. transformer의 output인 $X_{G} \in R^{P \times N \times d}가 patch order와 pixel order를 잃지 않았기 때문에 $X_{F} \in R^{H \times W \times d}$로 복원
-> convolution의 folding에 해당

개인 적으로 transformer를 convolution 처럼 볼수 있다는 해석 보다 non-overlaping 패치 간의 feature relation을 모델링 하기 위해 swin transformer 가 axis roll 을 사용 한것 보다 본 논문에서 convolution과 transformer 블럭을 사용해서 local, global feature representation을 학습하는 전략이 더 간단하고 효과적으로 보인다.

light-weight:

기존 convolution + transformer 사용 네트워크가 heavy했기 때문에 같은 레이어의 조합을 사용하는 MobileViT가 왜 light-weight 이 가능 한지 의문이 들 수 있다. 저자는 그 이유를 다음과 같이 말한다.
기존의 convolution과 transformer를 사용하던 네트워크는 spatial information을 latent로 바꾼다.
이게 무슨 말인고 하니 아래 그림 Fig 3을 보자.

Fig 3.



transformer 적용시 인접 픽셀을 채널축으로 stack 하고 픽셀 값들을 linear 연산을 이용해 latent로 보내는 embeding 연산이 image-specific inductive bias를 잃게 하는데 반해
MobileViT는 convolution과 transformer를 convolution의 특징을 살리면서 global representation을 배울 수 있는 방향으로 사용하기 때문에 light weight이 가능 하다는 입장이다.

Multi-head self-attention의 계산 복잡도를 비교해 보면
MobileViT: $O(N^{2}Pd)$
ViT: O(N^{2}d)
로 MobileViT가 더 비효율 적으로 보인다. 하지만 실제로는 MobileViT가 DeIT보다 약 1/2배의 FLOPs 를 가지고 ImageNet-1K 에서 1.8%더 높은 accuracy를 보였다.
이럴수 있는 이유 역시 convolution과 transformer를 서로의 장점을 살리는 방식으로 조합했기 때문에 가능했다는 것이 논문의 주장이고 결과가 좋으니 맞는 말로 보인다.

Multi-Scale Sampler For Training Efficiency:

MobileViT를 학습 시키기 위해 multi scale training 전략을 사용 했는데 기존 방식과 크게 두가지 다른점 이 있다.
이 두가지 다른점은 기존 multi-scale training 방식의 다음과 같은 단점을 보완한 것이다.

1. ViT 계열 네트워크는 multi scale training을 위해 각 scale 별로 네트워크를 fine tuning 하는 방식을 취한다. 왜냐면 ViT의 positional embedding이 입력 이미지 사이즈에 따라 interpolation 되어야 하고 네트워크의 성능이 이 positional embedding의 interpolation 방식(?) 에 영향을 받기 때문이다.

2. CNN 계열 네트워크들은 학습 중 미리 정해진 입력 이미지의 사이즈 $set S=((H_{0},W_{0}), ..., (H_{n}, W_{n}))$중 하나를 정해진 iteration 마다 선택해 학습에 활용하는데 이렇게 하면 batch size가 가장 큰 입력 이미지 사이즈에 의해 고정적으로 결정되기 때문에 작은 입력 이미지 사이즈를 이용해 학습할 때는 GPU 사용율이 떨어질수 밖에 없다. Fig 4 (a) 의 Multi-scale sampler를 보면 여기서 지적한 문제를 도식화 했다. Standard sampler에 비해 gpu memory utilization이 떨어지는 것을 표현했다.

위 두 가지 문제를 해결하는 multi scale sampling scheme은 다음과 같다.
1. MobileViT의 경우 positional embedding이 필요 없기 때문에 파인튜닝 방식으로 multi scale training scheme을 사용 할 필요 없이 CNN 계열 네트워크 처럼 학습 중 기 정해진 방식으로 multi scale training을 사용 한다.

2. 기존 방식에 존재 하던 비효율적 batch size 선택을 개선하는 방법으로 가장 큰 입력 이미지 사이즈를 $(H_{n}, W_{n})$, 이때의 batch size 를 b 라 할때
i-th iteration의 배치 사이즈 $b_{t}$ 는 $b_{t} = \frac{H_{n} * W_{n} * b}{H_{i}*W_{i}} 를 사용하게 함으로써 gpu utilization 문제를 해결한다.

결과적으로 위와 같은 2가지 보완점을 적용한 multi-scale sampler를 사용 할 경우 Fig 4 (b) 같이 standard sampler 보다 학습 효율이 좋아지는 것을 볼 수 있다.

Fig 4. About Multi scale sampler

Experimental result:


Dataset: image classification on The ImageNet-1k Dataset
implementation details:
MobileVit를 from scratch로 imageNet-1k 에 학습 시킴, 1.28 million training image, 50,000 validation image.
GPU: 8 NVIDIA GPU 사용
framework: pytorch
batch size: 1024
epoch: 300
기타 :
label smoothing cross-entropy(smoothing=0.1),
multi-scale sampler(S={(160,160), (192,192), (256,256), (288,288), (320, 320)},
L2 waeight dacay( 0.01),
lr scheduler: consine annealing (0.0002~0.002 warm start)

CNN모델들과 비교 :

Fig 5 CNN 계열 네트워크와의 비교


Fig 5에서 볼 수 있듯 MobileViT는 light-weight CNN들 보다 우수한 성능을 보여준다. Fig 5 (b)를보면 MobileViT-XS모델은 파라미터 수는 가장 작지만 top-1 acc 는 가장 높은 성능을 보여 준다. 심지어 Fig 5 (c)를 보면 heavy-weight CNN 네트워크와 MobileViT-S 모델을 비교 해도 상대적으로 작은 파라미터 수로 높은 top-1 acc 성능을 보여 준다.

ViT와 비교:

Fig 6. ViT 계열 네트워크와의 비교

CNN 계열 네트워크 들과 비교 결과와 거의 유사하다. Fig 6 (a), (b)에서 볼수 있듯이 MobileViT-XS, MobileViT-S 모델은 다른 ViT 모델들 보다 상대적으로 적은 파라미터 수로 더 좋은 top-1 acc 성능을 달성했다.
비교를 위해 ViT 계열 모델들 학습시 advenced augmentation사용, distillation 비사용, MobileVit는 basic augmentation을 사용했다.

Object detection 과 Segmentation:

아래 그림 Fig 7에서 보듯이 MobileViT는 object detection 과 segmentation task에서도 backbone으로서의 역할을 수행 할수 있고 성능 또한 MobileNetV2에 비해 좋은 결과를 보여준다. 단 inference time이 mobileNetv2가 압도적인 것으로 보이는 데 이것은 MobileNetV2 의 연산은 하드웨어 최적화가 잘되어있기 때문으로 논문의 appendix에 이 부분에 대한 분석이 포함되어있다. 꼭 읽어 보길 바란다. (사실 이유는 이미 언급했듯이 mobilenetv2는 하드웨어 최적화가 잘되어있기 때문으로 MobileViT의 연산을 효율적으로 지원하는 하드웨어 가속기가 존재한다면 MobileViT의 속도도 훨씬 빨라질 것이라고 하는데 이는 모든 구조가 다 마찬가지 아닌가 하는 생각..)

Fig 7 테스크 별 MobileViT 적용 성능


생각 해 볼만 한 사실들

Patch Size:

아래 그림 Fig 8 은 patch size 에 따른 inference 속도 변화 와 acc 성능 변화를 보여준다. patch size 에 따라 inference 속도와 acc 성능 변화가 나타나니 응용에 따라 주의 깊은 튜닝이 필요할 것 같다.

Fig 8 패치 사이즈 별 inference 속도와 분류성능 관계 patch 사이즈는 각각 32x32,16x16,8x8 의 spatial level에 해당


$nxn$ convolution kernel size 와 patch size ($ h \times w$)의 관계:

아래 그림 Fig 9에서 n은 convolution kernel의 크기, h,w는 patch 의 크기를 나타낸다.

Fig 9 patch 사이즈와 convolution kernel 사이즈의 관계

$h or w > n$ 일 경우 Fig 9 (c) patch 내의 각 pixel은 해당 convolution을 통해 patch 에 속한 모든 pixel의 정보를 취합 할 수 없으므로 local information의 취합 능력이 떨어진다. 이는 곳 전체 적인 성능 하락으로 이어진다.


Inference speed:

논문의 제목에서 알 수 있듯이 MobileViT는 edge device 에서 효율과 성능이 좋은 network를 목표로 했다.
하지만 아래 표 에서 보이듯 실제 iPhone 12 cpu, iPhone12 neural engine, NVIDIA V100GPU 에서 inference time을 비교 해보면 MobileNetv2가 위 언급된 모든 device에서 가장 빠르다.


iPhone에선 MobileViT는 DeIT, PiT 보다는 빠르지만 GPU 에서는 DeIT, PiT 가 오히려 빨랐다. GPU에서 MobileViT가 DeIT, PiT 보다 느린 이유는 1) MobileViT 모델이 shallow 하고 narrow 한 특성이 있고, 2) 256x256 이라는 좀더 큰 해상도 (DeIT는 224x224)로 동작 하기 때문이라고 한다. 또한 MobileViT 의 MobileViTblock 에서는 unfolding, folding 연산(Fig 1 참조)이 수행되는데 V100에서 이 두 연산을 gpu -accelerated operation을 사용 하는지 안하는 지에 따라 그 결과가 다르다. gpu -accelerated operation 을 사용 하지 않을 경우 0.62ms 이 걸리고 사용할 경우 0.47ms 이 걸린다. 그리고 MobileNetv2의 inference 속도가 빠른 이유는 mobileNetv2를 구성하는 연산을 서포트 하는 하드웨어 가속기의 덕분일 것으로 본다. MobileViT에 사용 되는 연산들이 하드웨어에 최적화되게 구현된다면 mobieViT의 inference 속도도 더 높아 질수 있을 것이라고 저자는 말한다.


-끝-

참조:
code link: https://github.com/chinhsuanwu/mobilevit-pytorch

GitHub - chinhsuanwu/mobilevit-pytorch: A PyTorch implementation of "MobileViT: Light-weight, General-purpose, and Mobile-friend

A PyTorch implementation of "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer". - GitHub - chinhsuanwu/mobilevit-pytorch: A PyTorch implementation of "...

github.com

이전 글 [Deskew for ocr] Rotation correction 에서 regression으로 회전된 문서를 바로잡는 방식을 시도해봤다.

 

이전 글에서 쓴 모델의 문제는 크게 2가지

1. 회전되지 않는 문서를 회전된 문서로 오인식해 오히려 이상하게 회전시키는 문제가 발생한다.

2. 정확도가 굉장히 떨어진다.

 

이를 해결하기 위해 새로운 방식을 조사, 적용 했고 결과가 좋아서 공유한다.

 

이번엔 회전된 문서를 바로잡는 문제를 분류 문제로 정의 하고 해결해 본다. 이 방식은 ocropus3 를 참고 했다.

 

코드는 아래 링크의 devStream_fft branch 에 있다.

link: https://github.com/pajamacoders/ocrDeskew/tree/devStream_fft

문제 정의

문서가 얼마나 회전되어있는 지를 분류 문제로 정의 하고 풀기 위해서 각 회전의 정도에 class를 부여 해야 한다.

 

나는 0.5도 단위를 하나의 클래스로 정의했다.

예를 들자면 아래와 같은 방식이다. 문서가 회전된 각도를 degree 라고 표현 했을 때 회전의 정도(degree) 가 -1 도에서 -0.5 도 이내이면 class 0에 배정 하는 방식이다.

range -1< degree<-0.5 -0.5<= degree <0 0<= degree <0.5
class 0 1 2

이러한 방식으로 -89~89도 사이에서 회전된 문서의 rotation correction 문제는

356 클래스를 가지는 분류 문제로 정의 할 수 있다.

 

개발 환경

개발 환경은 ngc repo에서 아래 이미지를 다운 받았다.

docker image : nvcr.io/nvidia/pytorch:22.01-py3

train metric tracking: mlflow

 

전처리

전처리 과정은 이전 글 [Deskew for ocr] Rotation correction 의 전처리와 거의 유사하지만 rotation의 정도에 class를 대응 하는 부분만 차이가 있다.

 

이전 글에서 바뀐 GT 생성 부분인 RandomRotation class 는 아래와 같다.

 

class RandomRotation(object):
    def __init__(self, ratio, degree, buckets=None):
        self.variant = eval(degree) if isinstance(degree, str) else degree
        self.ratio = eval(ratio) if isinstance(ratio, str) else ratio
        self.buckets = eval(buckets) if isinstance(buckets, str) else buckets

    def __call__(self, inp):
        if  np.random.rand()<self.ratio:
            deg = np.random.uniform(-self.variant, self.variant-0.1)
            img = inp['img']
            h,w= img.shape
            matrix = cv2.getRotationMatrix2D((w/2, h/2), deg, 1)
            dst = cv2.warpAffine(img, matrix, (w, h),borderValue=0)
            inp['img'] = dst
        else:
            deg = 0

        if self.buckets:
            rad = np.deg2rad(deg)
            range_rad = np.deg2rad(self.variant)
            bucket = int(self.buckets * (rad+range_rad) / (2*range_rad))
            inp['rot_id'] = bucket # 이 값이 문서가 회전된 정도를 class에 할당 한 값 즉 target class 값이다.

        inp['degree'] = deg
        return inp

 

모델

이번 모델의 특이점은 중간에 fft를 사용 하는 layer 가 들어간다는 것이다.

 

모델은 아래와 같다.

class DeskewNetV4(nn.Module):
    def __init__(self, buckets, last_fc_in_ch, pretrained=None):
        super(DeskewNetV4, self).__init__()
        buckets = eval(buckets) if isinstance(buckets, str) else buckets
        assert isinstance(buckets, int), 'buckets must be type int'
        k=5
        self.block1 = nn.Sequential(
                ConvBnRelu(1,8,k,padding=k//2),
                nn.MaxPool2d(2,2), #256x256
                ConvBnRelu(8,16,k,padding=k//2),
                nn.MaxPool2d(2,2), #128x128
            )
        
        self.block2 = ConvBnRelu(16,8,k,padding=k//2)
        self.fc = nn.Sequential(nn.Linear(131072,last_fc_in_ch, bias=False),
            nn.BatchNorm1d(last_fc_in_ch),
            nn.ReLU(True),
            nn.Linear(last_fc_in_ch, buckets, bias=False))

        self.__init_weight()
        if pretrained:
            self.load_weight(pretrained)

    def forward(self, x):
        out = self.block1(x)
        out = torch.fft.fft2(out)
        out = out.real**2+out.imag**2 
        out = torch.log(1.0+out)
        out = self.block2(out)
        bs,c,h,w = out.shape
        out = out.reshape(bs,-1)
        out = self.fc(out)

        return out

 

결과

모델의 last_fc_in_ch의 값으로 128을 사용 할 경우 아래와 같은 결과를 얻었다.

학습은 총 800에폭을 돌렸는데 굳이 이럴 필요까진 없었다.

optimizer로 adam을 사용 했고 lr_schedule은 cosineannealing을 lr range  1e-3 ~1e-6으로 사용 했다.

 

ce_loss:0.1241

precision:0.9466

recall:0.9426

f1_score:0.9427

 

아래는 train, validation 시의 f1 score의 값을 나타낸 그래프 이다.

200epoch 쯤 되면 성능 향상은 거의 없는 것을 볼 수 있다.

아래 그림은 임의의 숫자를 적은 문서로 테스트 한 결과 이다.

왼쪽이 입력으로 들어간 회전된 문서이고 오른쪽인 inference로 회전을 바로잡은 결과이다.

regression 모델 보다는 전체적인 결과가 훨씬 좋다.

 

- 끝 -

배경

이번 포스팅에서는 ocr 성능을 높이기 위해 간단한 전처리를 통해 회전된 문서를 올바르게 돌려 주는 문제를 풀어보고자 한다.

 

OCR 은 optical character recognition의 약자로 hand writing, 인쇄된 문서 등을 카메라로 찍거나 스캔했을때 그 문서 내의 글자를 인식해 문서를 전산화 할때 자주 사용되는 기술이다.

 

이때 ocr의 성능을 떨어뜨리는 문제 중 하나는 입력으로 들어오는 문서가 회전되어있는 경우 이다.

문서가 회전된 상태로 스캔되면 각 문자 자체는 인식이 할수 있지만 문자를 단어로 머지(merge)하는 과정이나 숫자의 연속인 여권번호, 운전면허 번호, 주민등록 번호 등 긴 문자의 경우 하나의 시퀀스로 머지 되지 않아 잘못된 패턴으로 인식 되는 경우가 종종 발생해  최종 인식률이 떨어질 수 있다. 

 

이런 문제를 해결하기 위해 인식 네트워크 자체가 문서 얼라인먼트 기능을 가지도록 설계할 수도 있고, hand craft feature extraction을 기가 막히게 설계 하고 이를 이용한 homography 계산 알고리즘을 설계하는 방식도 있겠으나!!!

여기서는 간단한 네트워크 설계하고 학습해 rotation correction하는 방식을 시도해 본다.

 

* 이 포스팅에 사용된 문서는 공공문서 양식 중 하나이고 이런 DB는 'aihub -> 음성/자연어->공공 행정문서 OCR'에서 아주 손쉽게 구할수 있다.

 

이 포스트에서 개발한 코드는 이 링크에 있다.

code :https://github.com/pajamacoders/ocrDeskew

목표

Fig 1. 왼쪽은 회전된 문서, 오른쪽은 회전 없이 정상적으로 스캔된 문서

이 포스팅의 목적을 명확히 나타내는 그림이 Fig 1. 이다. Fig 1 의 왼쪽은 회전된 문서를 보여준다. 우리의 목표는 이렇게 회전된 문서를 오른 쪽과 같이 회전이 없는 상태로 만드는 것이다.

 

문제 정의

나는 이 포스팅의 목표인 rotation correction 문제를 image orientation prediction 문제로 정의 했다.

 

사고의 흐름은 입력 이미지로 부터 이미지의 회전 정도(orientation)를 구하면 그 회전의 크기만큼 반대 방향으로 회전을 시켜 줌으로서 이미지를 정상적으로 만들수 있기 때문이다.

 

목표 추정 범위는 -30~30degree 이내의 회전으로 정했다.

회전의 크기를 일정 step으로 양자화 해서 각 구간에 class를 부과해 classification으로 해결 할 수도 있을것 같지만

회전 크기를 degree로 바로 추정하는 regression 문제로 정의 하고 풀고자 한다.

 

개발 환경

개발 환경은 ngc repo에서 아래 이미지를 다운 받았다.

docker image : nvcr.io/nvidia/pytorch:22.01-py3

train metric tracking: mlflow

 

전처리

전처리는 과정에서는 크게 아래 4가지를 수행했다.

  1. Color conversion(BGR 2 GRAY)
  2. Resize
  3. Rotation (이 과정은 target GT를 생성하는 과정이다.)
  4. Normalization

이때 resize시에 특정한 문제가 발생했고 이를 해결하기 위해 약간의 꼼수(=테크닉)가 사용되어 그 부분에 대해 잠시 설명 하고자 한다.

 

입력 이미지가 대부분 A4등의 문서를 스캔한 것이기 때문에 풀고자하는 문제에 비해 불필요 하게 크다고 생각했고 입력 이미지의 크기를 512x512로 정의 했다.

하지만 입력 중 어떤 이미지들은 3508x2480(width x height)의 크기를 지녔고 이걸 (512x512)로 리사이즈 하니 다음과 같은 이미지가 생성되었다.

Fig 2. 회전된 이미지

예제의 이미지는 그나마 좀 나은 편인데 심한 것들은 의미있는 글자나 선등이 sampling이 거의 안되어 노이즈 처럼 점만 뿌려져 있는 것처럼 보이기도 한다.

문제는 이러한 텍스트 문서의 경우 배경에 비해 정보를 담고 있는 글자나 테이블을 이루는 선등의 면적이 매우 작아 리사즈 할때 텍스트나 선 부분에서 샘플링되는 양이 매우 작기 때문이다.

 

위에 서 말했듯이 나는 원인이 정보를 포함하는 텍스트나 선이 차지하는 면적이 작기 때문으로 보았기 때문에 리사이즈 하기 전에 이 정보를 포함하는 부분의 면적을 키우기로 했다.

 

가장 먼저 떠오른 것은 dilate 연산이다.

dilate을 적용하기 위해 이미지의 색을 반전 시켜 준다. 색 반전의 이유는 입력 이미지는 흰색의 배경을 가지는데 기본적으로 dilate연산은 intensity가 큰 부분은 확장하는 효과를 가지기 때문에 내가 면적을 키우고자 하는 텍스트와 선 등 정보를 가진 pixel이 큰intensity 가지고 background가 낮은 intensity를 가지게 하기 위해서 이다.

아래 코드와 같이 이미지를 읽고 gray scale로 색변환 한 후 background와 foreground의 색상을 반전 시켰다.

    img = cv2.imread(self.img_pathes[index])
    # transform
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    if self.inverse_color:
        img = 255-img # inverse image to apply dilate
    res_dict = {'img':img, 'imgpath':self.img_pathes[index]}

다음으로 적용한 테크닉은 3508 x 2480 을 곧바로 512x512로 리사이즈 하는게 아니라

dilate ->1/2 resize -> 1/2 resize-> resize to 512x512 로 리사이즈 스탭을 진행해 orientation 추출에 필요한 정보를 최대한 많이 보존 하게 하는 것이다.

이 부분의 코드는 다음과 같다.

class Resize(object):
    def __init__(self, scale=4):
        assert (scale!=0) and (scale&(scale-1))==0, 'scale must be power of 2'
        self.iter = np.log2(scale).astype(int)
    
    def __call__(self, inp):
        img = inp['img']
        h,w = img.shape
        inp['org_height']=h
        inp['org_width']=w
        k = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3))
        img = cv2.dilate(img, k)
        for i in range(self.iter):
            h, w = h//2, w//2
            img=cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
            h, w = img.shape
        inp['img']=cv2.resize(img, (512, 512), interpolation=cv2.INTER_CUBIC)
        return inp

이렇게 만들어진 이미지는 아래 Fig 3과 같다.  Fig 2와 텍스트와 선분의 정보 손실 정도를 비교해 보자.

리사이즈 방식에 따라 최종 목적인 regression 성능이 얼마나 달라지는지 ablation study를 진행해 보지는 않았지만 정보의 손실의 최소화 하고자 하는 측면에서는 이게 맞지 않나 싶다.

Fig 3. resize로 인한 글자 정보손실을 최소화한 이미지

또 다른 방식으로는 텍스트의 뭉개짐을 무시하고 높은 intensity를 가지는 부분을 최대화 하고자 한다면 아래와 같이 입력을 변환 하는 방식도 가능 할 것이다.

Fig 4.

Fig 4에서는 글자를 알아 보기는 힘들지만 이미지의 orientation을 결정하는데에 foreground 정보의 양이 중요할 경우 유용 할 것이다. Fig 4는 위 Resize 클래스 에서 cv2.dilate() 함수를 for 문 안에 cv2.resize함수 호출 바로 앞으로 옮긴 경우 만들어 지는 이미지 이다.

 

GT 생성

위 이미지 전처리 과정에서 Rotation 과정은 아래와 같은 이유에서 추가되었다.

 

이미지가 회전된 정도를 GT로 만들어 놓은 경우는 거의 없기 때문에 정상적인 이미지를 opencv를 이용해 랜덤으로 -30~30degree 사이 각도로 회전시켜 회전된 정도를 GT로 사용 한다.

 

아래코드는 resize된 이미지를 임의의 각도로 회전하고 회전한 각도를 target GT로 저장하는 코드이다.

 

class RandomRotation(object):
    def __init__(self,degree):
        self.variant = eval(degree) if isinstance(degree, str) else degree

    def __call__(self, inp):
        deg = np.random.uniform(-self.variant, self.variant)
        img = inp['img']
        h,w= img.shape
        matrix = cv2.getRotationMatrix2D((w/2, h/2), deg, 1)
        dst = cv2.warpAffine(img, matrix, (w, h),borderValue=0)
        inp['img'] = dst
        inp['degree'] = deg
        return inp

 

모델 설계

입력의 회전된 각도를 추정하는 아주 가벼운 네트워크를 아래와 같이 구성했다.

import torch
import torch.nn as nn


class ConvBnRelu(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1):
        super(ConvBnRelu, self).__init__()
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, dilation, groups,
                      False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True))

    def forward(self, x):
        return self.conv_bn_relu(x)

class DeskewNet(nn.Module):
    def __init__(self, pretrained=None):
        super(DeskewNet, self).__init__()
        k=3
        self.backbone = nn.Sequential(
            ConvBnRelu(1,8,k,padding=k//2),
            nn.MaxPool2d(2,2), #256x256
            ConvBnRelu(8,16,k,padding=k//2),
            nn.MaxPool2d(2,2), #128x128
            ConvBnRelu(16,32,k,padding=k//2),
            nn.MaxPool2d(2,2), #64x64
            ConvBnRelu(32,64,k,padding=k//2),
            nn.MaxPool2d(2,2), #32x32
            ConvBnRelu(64,64,k,padding=k//2),
            nn.MaxPool2d(2,2), #16x16
        )
        self.avgpool = nn.AvgPool2d((16,16))
        self.fc = nn.Sequential(nn.Linear(64,64),nn.Linear(64,1))
        self.__init_weight()
        if pretrained:
            self.load_weight(pretrained)

    def forward(self, x):
        out = self.backbone(x)
        out = self.avgpool(out)
        out = self.fc(out.squeeze())
        return out

    def __init_weight(self):
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            else: 
                pass

 

학습

학습 환경은 아래와 같이 구성했다.

optimizer: Adam

lr_scheduler: CosineAnnealingLR -> 초기 lr =0.001, eta_min 1e-6, T_Max=300

loss: MSELoss

batch: 128

아래는 학습 코드 메인 문이다. train, valid 함수 구현은 뻔하니 생략한다.

if __name__ == "__main__":
    args = parse_args()
    with open(args.config, 'r') as f:
        cfg = json.load(f)
        cfg['config_file']=args.config
        if args.run_name:
            cfg['mllogger_cfg']['run_name']=args.run_name
    
    tr = build_transformer(cfg['transform_cfg'])
    train_loader, valid_loader = build_dataloader(**cfg['dataset_cfg'], augment_fn=tr)

    logger.info('create model')
    model = build_model(**cfg['model_cfg'])#torch.hub.load('pytorch/vision:v0.10.0', cfg['model_cfg']['type'])
    model.cuda()
    logger.info('create loss function')
    fn_loss = torch.nn.MSELoss()

    logger.info('create optimizer')
    opt=torch.optim.Adam(model.parameters(), **cfg['optimizer_cfg']['args'])
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,**cfg['lr_scheduler_cfg']['args'])

    max_epoch = cfg['train_cfg']['max_epoch']
    valid_ecpoh = cfg['train_cfg']['validation_every_n_epoch']
    logger.info(f'max_epoch :{max_epoch}')
    logger.info('set mlflow tracking')
    mltracker = MLLogger(cfg, logger)
    for step in range(max_epoch):
        train(model, train_loader, fn_loss, opt, mltracker, step)
        if (step+1)%valid_ecpoh==0:
            valid(model, valid_loader, fn_loss,  mltracker, step)
        lr_scheduler.step()

 

결과

아래는 train, validation loss를 mlflow tracking 으로 추적한 결과 이다. loss가 초반에 급격히 떨어지고 수렴은 하지만

validation이 중간 중간 불안정 하게 튀는 모습을 볼 수 있다.  이건 해결을 좀 해야겠다.

Fig5. loss graph

아래 Fig 6.의 왼쪽은 회전왼 입력 이미지이고 오른쪽 그림은 orientation을 추정해 correction 한 이미지 이다. 정확하진 않지만 생태는 많이 나아 졌다. 좀더 정확하게 correction 되도록 data검증 및 네트우크 수정을 해봐야겠다.

 

끝.

오늘 정리할 논문은 swin transformer로 요즘( 혹은 한동안) 핫한 transformer를 비전 테스크에 적용한 논문이다.
기존 CNN 기반의 backbone을 사용하지 않고 순수하게 transformer를 이용해 feature를 이미지에서 feature를 뽑아 낼 수 있다는 것을 보여준다.

 

Title :

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as

arxiv.org

git repo

 

GitHub - microsoft/Swin-Transformer: This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer u

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". - GitHub - microsoft/Swin-Transformer: This is an official implementation...

github.com

 

Contribution

Transformer는 NLP분야에서 LSTM을 대체 할 수 있는 방식으로 연구되다 비전쪽에서도 다양한 문제에 응용되고 있다.
하지만 텍스트와 영상은 그 특성이 서로 다르다. 구체 적으로 영상에 transformer를 적용하는 것에는 다음과 같은 문제를 해결해야 한다고 이 논문에서는 말하고 있고, 각각에 대한 본인들만의 해결책을 재시 했다.

1. word token과 달리 visual element는 scale이 다양하다.
NLP를 정확히 몰라 이부분에 대한 내 이해가 정확한지는 모르겠지만 나름 해석을 해보자면 NLP에서 각 word token은 고정된 크기의 embedding으로 변환 된다. 하지만 이미지는 그 크기가(resolution이) 640x480, 1024X768등 다양할 수 있다. 따라서, 영상에서 NLP의 token과 같은 고정된 크기로 표현 가능한 단위를 설정하는게 필요하다.
-> 일정 한 수의 pixel 집합을 patch라고 정의하고 이 patch를 token처럼 처리의 최소 단위로 정의해 이 문제를 해결

2. 영상은 텍스트에 비해 high resolution이다.
semantic segmentation 같은 경우 pixel 단위의 prediction이 필요 한데 transformer의 computational complexity가 image size에 qudratic 하게 증가 하기 때문에 연산량 증가의 문제가 발생한다.
-> 계층적인 feature map을 이용하고 feature map의 window 내에서 local self-attention을 적용함으로서 complexity 문제를 해결, 각 window에는 일정한 수의 patch 만 포함되도록 설계하여 전체 연산량이 window 수에 선형적으로 증가하게 설계(self-attention 계산 방식의 최적화)

3. self-attention을 ResNet의 spatial convolution 전체 또는 일부를 대체하는 방식이 제안된 적이 있으나 실제 하드웨어에서 latency issue가 발생한다( 이부분은 정확히 이해 하지 못했다. 언급된 방식들의 caching 능력이 떨어진다고 봐야 할 거 같은데... 혹 이 글을 읽고 있는 누군가 이에 대한 답을 안다면 댓글로 설명 좀 해주심이..)
-> shift windows를 사용해 해결

즉 swin transformer 의 key feature는 1. shift window, hierahical patch representation(계층적 패치 표현), 최적화된 self-attention 계산 방식 으로 볼 수 있을거 같다.


구조


이제 swin transformer의 구조를 살펴 보자. 여기서는 아래 그림 Fig1 의 (d)의 전체 구조를 참고해 입력 영상이 각 모듈에 들어가서 연산을 거칠때 어떻게 변해 가는지 그리고 위의 contribution에서 언급한 각 문제점의 해결책이 어느 모듈에서 어떻게, 왜 수행되는지를 정리 할 것이다. 여기서는 숲을 설명하고 나무를 설명하는 방식이 아니라 각 모듈을 나무로 생각하고 나무에 대해서 설명한다. 각 모듈에 대한 설명을 이해하고 Fig1 (d)의 보면 전체적으로 이해 하는데 도움이 될 거라고 생각한다.

Fig1. Swin transformer 구조

Fig1은 swin transformer 의 계측적 구조, shift window, transformer block 내부 구조, 전체 모델의 구조를 보여 준다. 우선 (d)의 전체 구조를 기준으로 살펴 보자.

swin transformer에서는 patch 라는 단위를 NLP의 token 처럼 사용한다. 이 patch를 이용해 고정된 크기의 embedding을 만들어 낸다. patch 라는 용어가 새로 나와 겁먹을 필요는 전혀 없다. 입력 영상에서 4x4 윈도우 내에 들어오는 pixel들을 concat해서 표현한 것이 patch이다.
예를 들어 입력 영상을 HxWx3, patch 크기를 4x4 pixel로 정하면, patch partition은 4x4 크기의 grid셀에서 그룹을 형성하는 pixel들을

R G B R G B ... G B

와 같이 concat해서 나타낸다. 이렇게 나타낸 patch는 4x4x3=48 의 크기를 지닌다. (4x4는 patch 크기 이고 3은 pixel의 channel 수 이다.) 이렇게 표현된 patch를 embedding으로 표현 하기 위해 linear layer를 이용해 연산한다.

아래 그림 Fig2는 Fig 1. (d)의 patch partition + linear embedding 모듈에서 입력 영상이 어떠한 형태로 변화 되는 지를 간단히 도식화 하고 있다. Fig2 의 가장 좌측 그림을 보면 굵은 선 안쪽에 얇은 선으로 4x4 필셀을 표현했다. 서로 다른 색으로 표혀한 것이 4x4 pixel 그룹이다. 이 4x4 픽셀들을 channel 축으로 concat하고 128channel의 embedding 을 만들기 위해 linear layer에 입력하고 swin transformer block에 입력하기 위해 spatial 축(영상의 가로와 세로) 를 HxW 으로 flatten 해주면 Fig 2의 가장 오른쪽같은 형태가 된다.

Fig 2. Patch partition + linear embedding 도식화


다음으로 swin transformer block에서는 multihead attention을 이용한 연산을 수행한다. 블럭 내에서 수행되는 연산은 Fig 1의 (c)를 보면 알 쉽게 알수 있다. (c)에서 W-MSA는 윈도우 내에서 수행되는 multi head self-attention을 의미한고 SW-MSA는 shifted window multi head self-attention을 의미한다.

swin transformer에서 self attention은 non-overlapped window내의 patch만을 이용해 수행된다. 여기서 윈도우란 patch의 집합으로 이해 하면된다. (patch는 인접 pixel의 집합, window는 인접 patch의 집합)

 W-MSA와 SW-MSA가 있는 이유는 input의 크기에 quadratic 하게 증가하는 computational complexity 문제를 해결하기 위해 도입한 window라는 개념 때문이다. 윗 문장에서 말했듯 swin transformer에서 정의하는 윈도우는 non-overlap이다. 즉 윈도우 내에 속한 patch 들 간의 연관성은 self-attention에서 고려 할수 있지만 서로 다른 윈도우에 속한 patch들 간의 연관성을 파악할 수 없다는 문제가 생긴다. 이를 해결 하기 위해 도입된게 SW-MSA(shifted window multi head self-attention)이다. Fig 1.의 (b)가 W-MSA와 SW-MSA에서 feature map을 어떻게 나누는지 보여준다. 우선 Fig 1. (b)의 왼쪽 그림은 사이즈가 4x4인 non-overlapped window로 feature map을 나눴을 때의 경우를 보여 준다(W-MSA에서 연산에 이용하는 window partition 방식이다.) Fig 1. (b)의 오른 쪽 그림은 SW-MSA모듈에서 사용하는 window partition 방식으로 W-MSA에서 서로 다른 window에 속해 있던 patch들이 같은 window로 묶이면서 상호간의 연관성(self-attention)을 고려 할 수 있게 된다.

 

 SW-MSA를 구현 할때 한 가지 중요한 아이디어가 들어가는데 바로 cyclic shifted window이다. window의 크기와 shifted window의 stride에 따라 영상의 가장자리에 patch들이 window를 가득 채우지 못 할 경우 이를 해결 하기 위해 feature map을 단순히 패딩 하면 연산량이 증가되는 손실을 감수해야 한다. 저자는 이러한 문제를 효율적으로 풀기 위해 feature map을 top-left 방향으로 shift 시켰다. 아래 그림 Fig 3은 이 방식을 도식화해 보여 준다.

Fig 3. cyclic shift&amp;amp;amp;amp;nbsp;

Fig 3 의 가장 왼쪽은 partition 된 feature map을 보여주는데 정 중앙에 위치한 M이 속한 윈도우만 실제 윈도우 크기에 부합하고 나머진 partition들은 윈도우의 크기보다 작게 된다. 각 partition을 모두 패딩해서 윈도우 크기인 4x4로 채우는 대신 Fig 3 cyclic shift 에 나타낸 것 처럼 A,B,C를 top-left 방향으로 회전(cyclic shift) 시키면 패딩을 하지 않아도(실제 코드에서는 물론 패딩도 들어간다. 하지만 패딩의 크기는 최소화 된다.) window 크기에 딱 맞는 partition을 만들 수 있다. 단 이렇게 하면 A,B,C는 원래의 feature map 에서 실제로는 서로 이웃 하지 않은 patch들과 같은 window에 속해 self-attention이 계산되는데 cyclic shift 하기 전의 feature map에서 서로 이웃 한 patch들간에만 self-attention이 계산 되도록 mask를 씌워준다. 이렇게 SW-MSA 연산이 끝나면 cyclic shift했던 것을 되돌려 준다. 

 

W-MSA, SW-MSA에서 또 한가지 언급 할 것은 Relative position bias이다. Attention module 에서는 $\frac{Q \cdot K_{T}}{\sqrtd} +B$ 와 같이 positional bias (B) 를 연산 과정에서 더 해준다. swin transformer에서는 relative position bias를 이용했는데 윈도우 크기를 M 이라 할때 한 윈도우의 각 축으로 방향으로 $ [-M+1, M-1]$의 상대 위치 offset을 정의하고 이 값들로 B를 구성해 positional bias로 이용했다. (transformer에서 positional bias는 매우 중요한 개념이다. 여기선 간단히만 언급 했지만 사용 이유와 의미를 꼭 파악하자. 스스로에게 하는 말이다. ) 

 

 아래 Fig 4. 은 Fig 1. (d)의 swin transformer block에서 입력이 data가 어떤 모양으로 변하고 어떻게 연산되는지를 간략히 도식화 한 그림이다. patch partition + linear embedding 단계에서 flatten되었던 입력 영상을 non-overlapped window로 분리하기 이해 공간 정보를 복원한다.(공간 정보를 복원한다는건 입력의 shape을 (batch size, HxW, embedding size)에서 (batch size, H,W, embedding size)로 reshape 하는 과정을 말한다.) 그 후 윈도우 크기로 분해된 feature map을 W-MSA또는 SW-MSA의 입력으로 넣어 attention 을 계산 해 준다. 어텐션 과정도 내 나름대로 도식화 했지만 내가 보기에만 좋은 그림 같기도 하다. 혹시 self-attention 연산 과정을 정확히 알지 못하는 독자는 여기 를 참조 하길 추천한다. (나에겐 정말 큰 도움이 되었다.)

Fig 4. swin transformer block 내의 연산 과정 및 데이터 변화

 

그런데 왜 이렇게 window 개념을 도입하면 computational complexity가 줄어드는 걸까?

MSA 연산은 크게 다음과 같은 단계로 구성된다.  -> 다음에 나오는 것은 각 단계에서 이루어지는 연산의 연산량이다.

1. input x와 weight $W_{Q}, W_{K}, W_{V}$를 이용해 $Q,K,V$ 계산

-> $Q = X \cdot W_{Q} => (hw \times C) \times (C \times C) = hw\times C^{2} $

    $K = X \cdot W_{K} => (hw \times C) \times (C \times C) = hw\times C^{2}$

    $V = X \cdot W_{V} => (hw \times C) \times (C \times C) = hw\times C^{2}$ 

2. $Q,K$ 를 이용해 attention score 구하기

-> $A = Q \cdot K^T => (hw\times C) \times (C \times hw) = (hw)^2C$

3. attention score와 $V$를 이용해 값 정재

-> $Z = A \cdot V =>(hw \times hw) \times (hw \times C) = hw\times C$

4. attention 적용된 output $Z$에 output weight $W_{z}$ 적용

-> $ out = Z \cdot W_{z} => (hw \times C) \times (C \times C) = hwC^2 $

 

이렇게 각 단계의 연산량을 다 더하면 MSA의 연산량은 $\Omega (MSA) = 4hwC^2 + 2(hw)^2C$ 이가 된다. 

그럼 W-MSA 는 어떨까?

위에서 언급한 4단계에서 2, 3단계의 연산량이 아래와 같이 바뀐다

2. $Q, K$ 를 이용해 attention score 구하기-> attention score를 윈도우 MxM에서 구하기

-> attention score를 구할때 고려하는 patch수는 hw가 아니라 MxM 즉 $M^2$

    $A = Q \cdot K^T = (M^2 \times C) \times (C \times M^2) = M^{4}C$

3. attention score와 $V$를 이용해 값 정재

-> $ Z = A \cdot V => (M^2 \times M^2) \times (M^2 \times C) = M^{4}C$

 

다만 윈도우가 $\lfloor \frac{h}{M} \rfloor \times \lfloor \frac{w}{M} \rfloor $ 개 있으므로 

$\frac{hw}{M^2} \times 2(M^4)C = 2M^2hwC$ 가 된다. 

따라서 window multi head self-attention을 적용하면 $\Omega (W-MSA) = 4hwC^2 + 2M^2hwC$ 가 되어 

$M^2 <= hw $ 일경우 연산량이 적어진다. 

 

마지막으로 계층적(hierachical feature)를 생성하기 위해서 patch merging 모듈에서는 patch의 숫자를 줄인다. 방식은 patch partition 에서 input의 RGB 채널은 concat한 것 처럼

각 stage의 출력을 patch mergin layer에서 2x2 grid 안에 들어오는 즉 서로 인접한 4개의 patch를 채널 축으로 concat 해준다. 

 

이걸 convolution의 receptive field관점으로 해석 하면 stage 1에서 4x4 가 receptive field이고 stage2dptjsms 8x8, stage 3 에서는 16x16, stage 4에서는 32x32 와 같이 볼 수 있다. Fig 1의 (a)는 이것을 도식화 한 것이다. 

 

 

 

additive Angular margin loss


Additive angular margin loss 는 metric learning에서 사용하는 loss로 ArcFace라는 논문에서 face recognition문제를 해결하기 위해 제안했다. 얼굴인식 문제를 풀기위해 제안되었을 뿐 similarity 문제를 푸는 대부분에 응용이 가능 하다.

목표는 triplet loss와 유사하게 (동작 방식은 전혀 다르다 목표만 유사하다) intra-loss는 감소 시켜 같은 클래스에 속한 입력들이 feature space에서 응집력을 가지게 학습시키고 inter-loss는 크게 해서 서로 다른 클래스에 속한 입력들은 feature space에서 구분가능한 만큼 떨어지게 학습시키는 것이다.
단 additive angular margin loss에서 의 거리는 Euclidian 거리가 아니라 loss의 이름에서 유추 가능 하듯 angular 거리 즉 각도의 차 이다.  (왜 각도 차 인지 아래  '개념' 섹션을 참조하면 이해에 도움이 될 것 이다.)

- intra-loss: 같은 클래스에 속한 입력의 feature 들 간의 loss
- inter-loss: 다른 클래스에 속한 입력의 feature 들 간의 loss

장점


softmax나 triplets loss 는 다음과 같은 문제를 가지고 있다.
1. softmax의 문제
- 분류해야하는 클래스의 개수가 증가 할 수록 fc 레이어의 아웃풋의 크기가 선형 증가 하게 되어있다.
- 학습된 feature는 Closed-set(폐쇄형) 분류 문제에는 충분할 수 있지만 Open-set(개방형) 얼굴 인식 문제에 충분한 분별력을 갖지 못한다.(이건 논문에 써있는데 얼굴 인식 문제를 직접 풀어 본적이 없어서 실제로 그런지는 잘 모르겠다. )
2. triplets loss의 문제
- triplets loss를 설명 할때 [주의사항]으로 써놓았는데 triplets은 (anchor, positive, negative) 3쌍의 sample이 필요하다. 입력의 class가 커지면 저 3쌍의 결합 개수는 폭발적으로 증가 할 수 있다.
- semi-hard triplets 를 선택하는게 굉장히 힘들고 비용 소모가 크다.

고로 장점은
triplets loss 처럼 semi-hard 한 샘플을 신중하게 선택할 필요가 없으며 softmax 보다 더 분별력있는 feature를 학습 할수 있게 해준다.
아래 그림은 softmax와 additive angular margin loss를 이용해 학습한 feature를 2D 공간에 그린것이다.
이렇게 보면 softmax보다 additive angular margin loss가 훨씬 decision boundary가 명확해 보인다.

Fig 1. (좌) softmax를 이용해 학습한 feature의 2상에서의 분포, (우) additive angular margin loss를 이용해 학습한 feature의 2d 상에서의 분포

연산

Fig 2. additive angular margin loss 


위 그림은 additive angular margin loss가 어떻게 동작하는지 보여준다.

classification 문제를 기준으로 설명 하자면 대부분의 DCNN(Deep Convolutional Neural Networks)모델의 마지막 레이어는 FC(fully connected) layer이다.
이 FC layer의 출력을 feature 라고 하고 기호로 $x$라 하자. (위 그림의 $x_i$가 바로 이 feature이다.)

$x \in R^d$  _($x$가 d차원 상에 있다면)_ 라 하고 $W \in R^{dxn}$ 이라 하자.  여기서 $n$ 은 클래스의 개수 이다. (아래 저차원에서의 예를 들어 다시 설명할 것이다.)
additive angular margin loss 의 입력으로 들어온 $x_i$는 normalization 된 후
normalization 된 W와 메트릭스 곱 $W^Tx$연산을 한다.
우리는 $W^Tx$ 연산을 $W$의 각 column에 해당하는 $w_j$ $j \in (1, ..., n)$ 벡터와 입력 $x$ 벡터의 내적 연산 즉 $$w_j \cdot x$$으로 해석 할 수 있다. ($w_j$ 가 $W$의 column 벡터 이므로 $R^d$ 상에 존재 하며 $x$와 차원이 같으므로 내적이 가능하다.)

내적연산이 $\vec{a} \cdot \vec{b}= |\vec{a}||\vec{b}|\cos\theta$ 임을 떠올려 보자.

우리의 상황에 적용하면 아래와 같이 된다.
$$\vec{w_i}\cdot\vec{x}= |\vec{w_j}||\vec{x}|\cos\theta$$
여기서 $w_i$, $x$는 normalization 된상태 이니 $|\vec{w_i}|=1$, $|/vec{x}|=1$ 이다. 즉, $w_j  \cdot x=\cos\theta$ 이다.

$w_j  \cdot x=\cos\theta$ 의 값을 $acos$에 대입하면 $\theta$를 구할 수 있다. 이 $\theta$에 $margin$을 더해 주고

scale $s$를 곱하면! 위 그림 Fig2 의 별표 펴진 $s*\cos(\theta+m)$ 이 구해진다. 

이렇게 구한 값을 원래의 softmax의 입력으로 넣으면 additive angular margin loss가 완성고 아래와 같이 정리 할 수 있다. 

$$L = -\frac{1}{N}\sum_{i=1}^{n} \log \frac{e^{s*\cos \theta_{y_{i}}}}{e^{s*\cos \theta_{y_{i}}}+\sum_{j=1,j \neq y_{i}}^{n}}$$

 

개념

additive angular margin loss 가 동작 하는 기본 개념은 각도 차이이다. 

고차원은 어려우니 저차원에서 예를 들어 보자.

입력을 아래와 같다고 하자. $$x=(1,2,3)$$, $$W_{3,2}=\begin{pmatrix} 1 & 4 \\ 2 & 5 \\ 3 &6 \end{pmatrix}$$ 

노멀라이제이션 하면 아래와 같이 된다. ($W$ 매트릭스는 column 축으로 normalization한다.)

$$ x=(\frac{1}{\sqrt{14}}, \frac{2}{\sqrt{14}}, \frac{3}{\sqrt{14}})$$

$$W_{3,2}=\begin{bmatrix} \frac{1}{\sqrt{14}}& \frac{4}{\sqrt{77}} \\\frac{2}{\sqrt{14}} &\frac{4}{\sqrt{77}} \\\frac{3}{\sqrt{14}} &\frac{4}{\sqrt{77}}\end{bmatrix}$$ 

 

두 입력을 선형 연산 하면 

$$W^{T}x= \begin{bmatrix} a &\\ b \end{bmatrix}$$

가 되고 a,b는 normalization 되었으니 그 크리 $|a|, |b|$ 는 1이다. (a,b 는 $W^Tx 의 결과를 나타내는 두 값이다. 수식을 다 쓰기 힘들어 저렇게 간단하게 표현했을 뿐이다.)

그때 $W$의 각 column vector (1,2,3), (4,5,6) 각각을 3차원 공간 상에서 2개 클래스의 feature의 centre 좌표로의 벡터로 본다면

입력의 feature인 $x$와 $W$의 columnd 벡터들의 내적은 각 클래스를 대표하는 feature 벡터와 입력의 feature가 이루는 각이 얼마나 작은가? 를 알아 보는 과정으로 해석 할 수 있다. 

 

각 클래스의 대표 벡터인 $W$의 column 벡터들과 $x$ 가 모두 normalization 되어있으니 그들의 magnitude는 모두 1로 같으므로

각도 (angular)의 차이만 보겠다는 뜻이다.

기하적으로 해석하면 magnitude가 모두 1이므로 각 클래스의 대표 feature 벡터 와 입력의 feature 벡터 모두 이 예에서는 3차원 공간에서 구의 표면에 위치하고 addtive angular margin loss 는 이 구의 표면 위에서 같은 클래스에 속하는 feature 벡터들이 구의 표면 상에서 서로 가까운 위치에 놓이도록 벡터간의 각도를 작게 한다는 개념이다. 

 

 

 

 

 

'Deeplearning > Loss' 카테고리의 다른 글

Generalized Focal Loss 리뷰.  (0) 2023.08.08
[Metric learning loss] Triplet loss 설명  (0) 2021.11.24

Metric learning loss

Metric learning 은 데이터간의 유사도를 잘 수치화 하는 거리 함수(metric fucntion)을 학습 하는 것이다.

Metric learning loss는 입력으로 부터 추출된 feature들 간 상대적 거리를 추정하기 위한 loss로 명시적인 목표 값이 주어지고 해당 목표값 추정을 목적으로 하는 cross entropyregression loss와 개념이 다르다.

Detection이나 segmentation에 흔히 쓰이지는 않지만 간단하게 말하면
같은 클래스에 속한 입력간의 거리는 가깝게 만들고 하고 다른 클래스에 속한 입력간의 거리는 최대화 할때 씌인다. instance segmentation에서 같은 클래스의 서로 다른 instance를 분리 할때 유사한 개념이 사용되기도 한다.

face identification, few shot learning,  recommendation  등에 사용할 수 있다.

널리 알려진 것중 이 포스트 에서 알아 볼건 triplets loss 이다.

 

triplets loss

Fig 1. Triplet loss의 개념


triplets loss는 $R^k$ 에서(k는 feature의 dimension)  같은 클래스 또는 특정한 기준에서 유사한 입력 사이의 거리를 가깝게 하고 서로 다른 클래스, 유사하지 않은 입력 사이의 거리를 멀게 만드는데 목적이 있다.  Fig 1. 에서 Anchor는 기준이 되는 입력이고
Negative는 입력과 다른 클래스 , Positive는 입력과 같은 클래스이다.

 

그림에서 알 수 있듯이 LEARNING 하고 난 후 Anchor,Positive sample사이의 거리는 가까워 졌고 Anchor, Negative sample 사이의 거리는 멀어 졌다.

 

입력이 이미지 일때를 기준으로 예를 들어 보자.

Fig 2. Triplet loss의 예


위 그림에서 입력을 얼굴 사진들을 $x_a$, $x_p$, $x_n$ 이라 했을 때
- $x_a$, $x_p$ 는 서로 같은 사람의 얼굴, 즉 같은 클래스
- $x_a$,$x_n$ 은 서로 다른 사람의 얼굴, 즉 다른 클래스
이다.


여기서 $x_p$, $x_n$을 $x_a$ 를 기준으로 같은 클래스, 다른 클래스로 구분을 했는데
이렇게 기준이 되는 입력을 anchor라고 표현하고 anchor 같은 클래스면 positive sample
anchor와 다른 클래스 이면 negative sample이라 한다.
(입력 $x$의 아래 첨자가 $a,p,n$인 이유가 anchor, positive, negative 의 앞 글자를 딴것이다. )

$d(x_i,x_j)$ 를 $R^k$ 에서의 $x_i$, $x_j$ 사이의 거리라 할때  triplet loss의 목적은 아래와 같다. 
- $d(x_a, x_p)$는 최소화
- $d(x_a, x_n)$는 최대화

이때 단순히 최대화 하는것은 목적이 불분명하니 특정한 기준 $\alpha$을 도입해
같은 클래스 간의 거리보다 서로 다른 클래스간의 거리가 $\alpha$ 만큼 크도록 수식을 구성하면 아래 식과 같다.
$$d(x_a, x_n) > d(x_a, x_p)+ \alpha     $$

위 수식은 $$ L(x_a, x_p, x_n) = d(x_a, x_p) -d(x_a, x_n) + \alpha $$ 으로 다시쓸 수 있는데
$d(x_a, x_n)$ 이 커져서 위 수식이 0이하가 되면 그로 부터 network가 배울 필요가 없으니
$$ max(0, d(x_a, x_p) -d(x_a, x_n) + \alpha)$$ 으로 사용한다.

 

triplet loss를 사용할때 주의 할점


이 로스는 $x_a, x_p, x_n$ 의 3쌍이 필요한데 이 3개의 쌍을 어떻게 선택하느냐에 따라 학습의 안정성과 성능이 달라질 수 있다.
$d(x_a, x_p)$ 가 너무 작은 쌍이 dataset의 대부분을 이루고 있다면 여기서 배우는 정보량이 너무 적을수 있고 $d(x_a, x_n)$가 이미 이미 $margin$ 역할을 하는 $\alpha$보다 크면 이런 샘플에서도 학습할 정보량이 부족하다.
특히 negative sample의 선택이 중요하다고 보는 경우가 많다.
negative sample에 따라
* Easy Triplets: $d(x_a, x_n) > d(x_a, x_p)+\ alpha$ 인 경우
* Hard Triplets: $d(x_a, x_n) < d(x_a, x_p) $ 인 경우
* Semi-Hard Triplets:$ d(x_a, x_p) < d(x_a, x_n) < d(x_a, x_p) +\alpha$ 인 경우
로 나누기도 한다.

거리를 계산하는 함수 $d(x_i, x_j)$는 Euclidian 등 원하는 방식을 선택하면 된다.

 

참조


pytorch에 torch.nn.TripletMarginLoss 가 있지만
$(x_a, x_p, x_n)$ 샘플을 자동으로 선택해주는 기능은 없는 것으로 보인다.  











 

'Deeplearning > Loss' 카테고리의 다른 글

Generalized Focal Loss 리뷰.  (0) 2023.08.08
[metric loss] additive angular margin loss  (0) 2021.11.26

+ Recent posts