새소식

딥러닝, 머신러닝/논문리뷰

Distilling the Knowledge in a Neural Network

  • -

모델의 구조가 정말로 중요할까?

그동안의 SOTA 논문들을 보면, Skip-Connection, Transformer, Graph Convolution 또는 시공간 통합형 구조와 같은 네트워크 구조의 변형을 통해 성능 향상을 시킨 경우가 많다. 또는 여러 모델의 결과를 합쳐서 사용하는 앙상블 방법도 대회에서 성능 향상을 위해 많이 사용되고 있다. 그렇지만 앙상블은 여러 모델이 합쳐진 만큼 컴퓨팅 비용이 더 많이 들어가고, 같은 자원이라면 추론 속도는 더 늦어질 수밖에 없다.

이들의 개념 모두 좋은 걸 알겠지만 이게 딥러닝에서 절대적으로 중요한 것인가에 대한 의문점을 해결해준 논문이 아닌가 싶다.

 

이 논문은 경량화에 관해 다루고 있는데, 기존의 복잡한 모델이 가진 지식을 좀 더 단순화된 모델에 전달하는 것(knowledge distillation)으로 원래의 모델에 근접하거나 그 이상의 성능을 낼 수 있다는 걸 보여주고 있다. 

 

Distillation

distillation

  • Distillation(증류) : 불순물이 섞여 있는 혼합물에서 원하는 성분을 분리시키는 방법
  • Knowledge Distillation (본 논문에서 제안한 방법) : 불필요하게 많은 파라미터가 사용되는 기존의 모델들(불순물이 섞인 혼합물)로부터, 보다 단순화된 모델에 지식(Knowledge)을 전달해서 핵심 부분만 살리고, 불필요한 부분 제거로 인해 모델의 속도도 개선하는 방법

 

Softmax

 

hard label과 soft label

보통 모델 학습을 진행할 때, 위 그림의 hard label([1, 0, 0])과 같은 형태로 정답 데이터를 가지고 진행한다.

그러나 실제 학습한 모델로 추론해보면 그 모델의 결과값(logit)을 softmax함수로 보내서 나온 값은 soft label(또는 soft target)로 [0.9, 0.0999, 0.0001]과 같은 형태로 출력된다(총합 1).여러 클래스의 데이터를 살펴본 모델의 입장에서 봤을 때, "이 사진은 90% 정도는 고양이 같은데, 약간 9.99%정도는 개같기도 하고, 털같은 질감이 미세하게(0.01%) 소같기도 하다" 라고 판단이 된 것이다.
이런 점에서 봤을 때, 기존의 정답데이터(또는 hard label, hard target)보다 soft label이 갖는 정보가 더 크므로 이 정보를 보내주는 게 학습에 유리할 것이라고 판단하고 연구를 진행한 것이다. 상대적으로 사이즈가 더 작은 모델을 학습시킬 때, "이 사진은 고양이야"라고 알려주는 것과 "이 사진은 거의 고양이인데, 약간 개같기도 하고, 아주 조금은 소 느낌도 있어"라고 알려주는 것은 확실히 다르다(뭔가 방법 자체를 알려주는 느낌).실제로 본 논문의 결과를 봤을 때, 같은 아키텍처의 모델을 정답 데이터만으로 학습한 경우보다 soft label로 학습한 결과가 더 좋았다.

 

하지만 여기서 문제가 있다. soft label을 알려줄 때 위 그림의 경우 "소"라는 클래스의 값은 매우 크기가 작아서 정보가 잘 전달되지 않을 수 있기 때문에 약간의 변형을 취해준다.

온도(Temperature)가 적용된 softmax

기존의 softmax 수식에 파라미터 T(온도, Temperature)를 추가해서 작은 값도 좀 더 잘 반영이 되도록 했다. T=1인 경우 기존의 softmax와 같다고 보면 된다. 논문에서는 T 값이 2~4 정도일 때 distillation이 효과적으로 적용되었다고 하고 있다.

 

Cross entropy

  • student model의 logit인 z와 teacher model의 logit인 v의 soft label을 cross entropy로 처리함

 

식만 봐서는 잘 이해가 가지 않아서 코드를 구현해 보았다.

https://github.com/woo1/KD_Study/blob/main/softmax.ipynb

T값에 따른 softmax 값의 변화

[0.3, 2.9, 4.0]이라는 logit 값(softmax로 보내기 전 모델의 최종값)이 나왔다고 했을 때, 원래의 softmax(빨간색)와 T값을 2(초록색),3(파란색)으로 줬을 때의 결과이다.

그림을 보면, T값이 늘어날 수록 원래 컸던 값들은 상대적으로 작아지고, 원래 작았던 값들을 상대적으로 커지고 있다. 이 방식으로 논문에서는 기존 softmax의 output을 더 soft하게 만들어서, 좀 더 많은 정보가 전달되도록 하고 있다.

 

Experiments

1. MNIST

MNIST Dataset

  아키텍처 결과 (test error)
baseline 784 → 800 → 800 → 10 146
baseline+ 784 → 1200 → 1200 → 10 67
knowledge distillation 784 → 800 → 800 → 10 74 (↓72)
  • baseline+ : baseline보다 높은 성능의 모델을 만들기 위해 파라미터 수를 더 늘리고, dropout과 input jittering 옵션을 추가한 모델
  • knowledge distillation 적용을 위해 학습된 baseline+ 모델로 original training data에 soft label을 생성하고 원래의 hard label과 함께 활용하여 아키텍처가 baseline과 동일한 구조로 학습했더니, 기존 baseline보다 성능이 많이 개선됨

 

3이 빠진 MNIST

한 클래스 데이터를 아예 빼버리면 어떻게 될까?

논문에서 기존 모델의 지식을 가지고 있다면, 데이터를 본 적 없어도 그게 3인줄 알 수 있을지를 실험해봤습니다.

knowledge distillation 방식으로 학습한 결과, 숫자 "3"에 해당하는 데이터를 아예 보여주지 않았는데도 test error가 109이고(baseline은 146), test set에 있는 "3"(총 1010개)중 1.4%(14개)를 제외하고 다 맞췄다고 합니다.

 

모델은 학습 도중에 "3"이라는 이미지를 한번도 본 적이 없지만, soft label을 통해 배운 것이죠. 지금 이 숫자 "2"라는 이미지는 "3"이라는 이미지와 5% 비슷하고, 숫자 "9"라는 이미지는 "3"이랑 8% 비슷하다던데, 그럼 이게 3일까? 할 수 있게 학습된 것이죠.

 

80% 클래스를 다 지워버리면?

7, 8만 남은 MNIST

앞의 실험에 이어 추가적으로 대부분의 카테고리를 없애고 학습을 진행합니다. 7, 8 데이터만 남기고 진행했는데도 모든 카테고리에 대해 87%의 test accuracy가 나왔다고 합니다. 

 

MNIST라서 되는 거 아니야? Speech Recognition은?

8개의 hidden layer, 14000개의 label이 있는 softmax layer가 있는 모델을 baseline으로 하여 파라미터 초기화를 다 다르게 Random하게 한 10개 모델을 앙상블한 결과와 1개 모델의 baseline, distillation 결과입니다.

약 2000시간의 영어 음성 데이터를 사용하였다고 합니다. 나름 사이즈가 큰 데이터셋에서도 이 논문의 방법이 효과가 있다는 게 입증되었습니다.

 

본 논문의 설명으로 볼 때 10개의 모델을 별도로 학습하였다고 하니까 10개 모델을 따로 학습하고 각 모델의 soft label을 수집해서 distillation에 활용하지 않았을까 싶습니다. 이렇게 하면 학습 리소스 면에서도 별 부담없이 앙상블의 효과를 거둘 만큼 할 수 있어서 좋을 것 같습니다.

 

그럼 엄청 큰 데이터셋에서도 이게 효과가 있을까? - JFT Dataset

  • JFD Dataset : 15,000개의 레이블, 1억개의 레이블이 지정된(전처리된) 이미지가 있는 데이터셋
  • baseline : 약 6개월간 학습된 대형 CNN 모델

dataset이 크고, 학습시간이 너무 오래 걸려서 single dustbin class 방식을 사용해서 나중에 앙상블하는 형식으로 했다는데 이게 binary classification을 하는 모델들로 여러 개 만들었다는 건지 확실치가 않네요.

 

Regularizers로서의 Soft Targets

앞의 Speech Recognition 모델의 데이터를 3%만 사용했을 때의 결과

  • MNIST에서는 적은 데이터셋만으로 좋은 결과가 나왔는데 음성 데이터에서도 되는지를 실험한 것 같습니다.
  • 전체 중 3%(약 2000만개)로 학습을 진행했을 때, hard label만으로 학습을 진행하면 Train Accuracy는 baseline보다 더 올라가고(67.3%), Test accuracy는 훨씬 떨어지는(44.5%) 결과가 나왔습니다.
  • 반면, 같은 3% 데이터라도 soft label을 사용했을 때는 Train Accuracy는 조금 올라가고(65.4%), Test Accuracy는 거의 차이가 없는(57.0%) 결과가 나왔습니다.
  • 이는 Soft target이 regularizer의 역할을 할 수 있고, 3%밖에 안되는 상대적으로 적은 데이터도 충분히 활용 가치를 낼 수 있도록 해주는 것을 보여주고 있습니다.

 

코드

Student, Teacher 모델 (위 논문설명의 baseline+, baseline)
loss

코드 링크 : https://colab.research.google.com/drive/1oweLs9ttGEYVIw5H8EmmXggL8SIoZfGK?usp=sharing 

두 분포가 얼마나 다른지를 측정하는 KLD Loss와 student prediction, hard label과의 차이를 구하는 cross entropy loss를 같이 사용하고 있습니다. 

해당 코드에서는 T=20으로 두고 alpha 값을 0.7로 둬서 hard label은 30%만 반영하도록 하고 있습니다. (soft label에는 T의 제곱 + alpha로 코드를 짜놔서 위의 경우는 400:0.3으로 hard label은 실질적으로 0.1% 정도 반영되고 있습니다.)
Student 모델 학습 결과(distillation) - 100 epoch로 Teacher 모델과 epoch 수가 다르긴 한데 30 epoch 일 때도 거의 차이가 없다. accuracy 0.01~0.02 차이 수준
Teacher 모델 학습 결과

결과적으로 간단한 설정으로 2%의 성능 향상이 나타났다.

왼쪽 : Teacher model loss, 오른쪽 : Student model loss

특이한 점은 이미 학습된 teacher 모델의 soft label을 사용한 것 때문인지 수렴이 빨리 되는 모습이 나타난다.

 

코드는 딥러닝 공부방 블로그(https://deep-learning-study.tistory.com/700)를 참조하였습니다


참고자료

https://blog.lunit.io/2018/03/22/distilling-the-knowledge-in-a-neural-network-nips-2014-workshop/?blogsub=confirming#subscribe-blog 

https://arxiv.org/pdf/1503.02531.pdf

 

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.