새소식

딥러닝, 머신러닝/논문리뷰

FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning

  • -

https://www.youtube.com/watch?v=fG2Wf2f7UoA&t=1s 

발표자료

발표 영상을 참고하시면 좋을 것 같아 공유드립니다.

 

1. Introduction

  • 이미지 기반 augmentation과 같은 기존 방식보다 다양한 변환을 위해 dataset에 있는 다른 이미지의 특징에서 추출된 대표 프로토타입의 작은 세트에 대한 soft-attention을 통해 입력 이미지 특징을 개선하고 보강하는 방법을 학습하는 모듈을 제안한다.

  • 제안한 module은 feature에서 학습되고 수행되기 때문에 입력 이미지에 더 다양하고 추상적인 변환을 적용할 수 있다.
    (클래스별 프로토타입 추출 : k-means)
  • 이 모델에서는 consistency regularization을 할 때 기존의 이미지 기반 augmentation과 제안된 feature based augmentation을 같이 적용하면 큰 발전을 얻게 되는 걸 보여주고 있다.
  • 이를 통해 Semi Supervised Learning 방법 발전에 기여

 

2. Related Works

2-1. Consistency Regularization Methods : Semi-Supervised Learning에서의 최근 대부분의 SOTA 모델들이 쓰는 방법, 모델의 input에 augmentation을 해서 새로운 input을 만들었을 때, output(prediction)이 별로 변하지 않아야 된다고 가정하고 모델을 regularize하는 방법, 모델이 이미지의 질감 또는 기하학적 변화(geometric changes)가 있어도 영향을 많이 받지 않도록 정규화한다.

 

2-2. Image-Based Augmentation : 밝기, 색조, 선명도 또는 회전 등을 사용하거나 동일한 데이터의 다른 view(prediction)를 사용한다. 또는 두 이미지를 조합하는 mixup과 같은 방식을 사용하기도 한다. 

Mixup

 

3. Feature-Based Augmentation and Consistency

  • 기존의 image-based augmentation : 동일한 클래스 내부 또는 외부의 다른 instance에 대한 지식 필요, Mixup의 경우는 mixup이 두 인스턴스 사이에서만 작동하기 때문에 기존 augmentation 방식의 한계를 부분적으로만 해결한다.
  • Manifold Mixup은 feature space에서 mixup을 수행해서 첫번째 한계에 접근하지만 두 샘플의 단순한 convex combination으로 제한된다.
    convex combination : 주어진 점들을 서로 연결한 도형 안에 존재하는 지점들
  • feature space에서 clustering을 수행해서 각 클래스의 정보를 작은 prototype set으로 압축하고, 모든 클래스의 prototype에서 전파된 정보를 통해 이미지 feature를 refine하고, augment한다. 
  • feature refinement, augmentation은 대표적인 prototype에 대한 lightweight attention network를 통해 학습된다. 

Convex Combination(출처: wikipedia)

3.1. Prototype Selection

- 다른 클래스의 knowledge를 feature refinement, augmentation에 효율적으로 활용하기 위해 feature space에서 clustering하여 각 클래스의 정보를 간결하게 표현하도록 한다.

- Feature space에서 K-means clustering을 사용해서 각 클래스에 대한 프로토타입으로 pk cluster mean을 추출하여 사용한다. 하지만 두가지 문제가 있다. (p-k-means는 k-means에서 centroid를 더 효과적으로 추출하기 위한 클러스터 초기화 방법임)
- SSL에서 대부분의 이미지에는 label이 지정되어 있지 않고, 모든 label을 사용할 수 있더라도 K-means를 실행하기 전에 전체 dataset에서 모든 이미지의 feature를 추출하는 게 계산 비용이 많이 든다.

- 그래서 위의 그림과 같이 Training Loop의 모든 반복에서 네트워크에서 이미 생성된 feature fx_i 및 pseudo label ^y_i를 수집한다.  (pseudo label : unlabeled data의 target class, 가장 확률이 높게 예측된 class)

- Recording Loop에서 pseudo label 및 feature 쌍은 computation graph에서 분리되어 나중에 사용하기 위해 memory bank에 push된다. 

- 프로토타입은 전체 dataset을 검토할 때 모든 epoch에서 K-means에 의해 추출되고, feature refinement 및 augmentation module은 training loop에서 새로 추출된 것으로 프로토타입을 업데이트한다.

 

3.2. Learned Feature Augmentation

- 위에서 얻은 prototype set을 이용하여 soft attention을 통해 학습된 feature refinement 및 augmentation 적용

- 위의 그림처럼 3개의 Fully Connected layers로 구성되어 있다.

- prototype feature에 attention weights를 합산하고, residual connection을 통해 입력 이미지 feature로 feedback한다.

- input image에서 추출된 feature fx와 class의 prototype을 각각 미리 학습된 function을 통해 처리해서 embedding값을 만들고, dot product를 해서 softmax 함수로 처리한다. 

 

3.3. Consistency Regularization

- 선택된 prototype과 함께 학습된 AugF 모듈은 더 나은 표현을 위해 입력 feature를 개선할 수 있으므로 더 나은 pseudo label을 생성할 수 있다. 따라서 pg = Clf(gx)에 의해 refined feature fx에 대한 pseudo label인 pg를 계산한다. 

- feature based consistency loss : Lcon = H(pg, Clf(fx))로 계산

- 약하게 변형된 이미지 x와 강하게 변형된 복사본 x^를 생성하여, pseudo label은 pg=Clf(AugF(Enc(x)))로 약하게 증강된 이미지 x로 계산한다. 그리고 강력하게 증강된 데이터 ^x에서 두가지 consistency losses를 계산한다. (하나는 AugF가 적용된 것, 다른 하나는 적용되지 않은 것) 

 

3.4. Total Loss

- Unlabeled data에 관해서는 3.3의 두가지 Loss가 적용되고 labeled data(x, y)에 관해서는 아래의 Loss를 적용한다.

- Total Loss는 labeled data와 unlabeled data의 Loss를 합하여 적용한다.

 

4. Results

- CIFAR-100, mini-ImageNet의 4K label 데이터셋에서 다른 SSL SOTA 기술들보다 5%, 17% 향상됨

- (error rate) 다른 도메인에서 오는 unlabeled data를 포함하는 DomainNet 설정에서의 결과, unlabeled 데이터의 50%가 다른 도메인에서 오는(ru=50%) 설정에서 성능이 큰 차이로 떨어진다. 그렇지만 semi-supervised baseline에 비해 ru=0%일 때 16%, ru=50%일 때 12% 향상되고, supervised baseline에 비해서도 더 나은 결과를 보여준다.

다른 SSL 방법들과의 비교

- 다른 SSL 방법들에 비해 더 낮은 baseline에서 시작하고, 더 간단하지만(class distribution alignment 및 self-supervised loss 없음) 현재 SOTA인 ReMixMatch와 유사항 성과를 얻었다. 

 

5. Ablation Study

5.1. Lcon-f 및 Lcon-g의 효과는 무엇인가?

5.2. image based augmentation baseline에 비해 제안된 방법이 얼마나 개선되었는가?  

- image based augmentation의 경우 동일한 수의 labeled samples을 사용했을 때, FeatMatch를 사용했을때보다 더 안좋은 결과가 나오고, labeled sample의 수가 적을수록 개선이 많이 되고 있다.

- Lcon-g보다 Lcon-f가 더 중요한 역할을 한다는 걸 알 수 있지만 두 Loss가 모두 있는 경우에 최상의 결과를 얻는다. 

그림4

AugD (image-based augmentation), AugF(feature-based augmentation)

- (c)에서 AugF는 AugF로 변환된 feature와 유사한 feature를 갖는 이미지를 가져온 것

 

Analysis

- What augmentation does AugF learn? : 그림4를 보면 AugD는 local에서만 데이터를 변형할 수 있고, feature space에서는 더 효율적인 consistency regularization을 위해 더 강력한 augmentation을 생성하지 못하고 있다. 그림 4b를 보면 AugF는 실제로 기능을 augment하고 refine하는 방법을 학습하는 걸 볼 수 있다. 확대한 이미지를 봐도 AugF 방법이 feature space에 맞도록 효과적이고 더 강력한 augmentation을 하는 걸 볼 수 있다. 

- What other reason does AugF improve model performance? : AugF 모듈이 추출된 prototype에 의해 더 나은 representation을 위해 입력 이미지의 feature를 refine할 수 있고, 따라서 더 나은 pseudo label을 제공할 수 있기 때문이라고 가정함

AugF에 의해 refine된 더 나은 feature representation을 학습한다. 

- What does Aug do internally? : 그림5b에서 AugF 내부의 attention 메커니즘은 입력 이미지 feature와 동일한 클래스에 속하는 프로토타입에 attend하도록 학습한다.

그림5

6. Code

 

 

참고자료

https://seewoo5.tistory.com/8

 

Regularization With Stochastic Transformations and Perturbations for Deep Semi-Supervised Learning

이번 포스트에서는 (제가 알기로는) Consistency Regularization이 처음 소개된 논문인 Regularization With Stochastic Transformations and Perturbations for Deep Semi-Supervised Learning이라는 논문에 대해..

seewoo5.tistory.com

https://light-tree.tistory.com/176

 

Convex 의 정의와 설명

Convex set, convex combination, convex hull, convex function 이 각각 무엇인지 살펴보겠습니다. Convex set 이란? Convex set(좌) 과 Non convex set(우) 이미지 출처:en.wikipedia.org/wiki/Convex_set 어..

light-tree.tistory.com

https://github.com/GT-RIPL/FeatMatch

 

GitHub - GT-RIPL/FeatMatch: PyTorch code for the paper: "FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning"

PyTorch code for the paper: "FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning" - GitHub - GT-RIPL/FeatMatch: PyTorch code for the paper: "FeatMatch: Feature-Based A...

github.com

https://arxiv.org/abs/2007.08505

 

FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning

Recent state-of-the-art semi-supervised learning (SSL) methods use a combination of image-based transformations and consistency regularization as core components. Such methods, however, are limited to simple transformations such as traditional data augment

arxiv.org

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.