Continual Segmentation with Disentangled Objectness Learning and Class Recognition
https://arxiv.org/pdf/2403.03477
0. Abstract
objectness를 이용한 query기반의 segmenter는 pixel단위의 segmenter와 비교했을 때 고유한 장점을 가지고 있다. objectness는 강한 전이능력와 forgetting resistance를 가진다는 것이다. 이를 기반으로 2stage(forgetting-resistant continual
objectness learning와 well-researched continual classification)의 continual segmentation CoMasTRe를 제안한다.
1. Introduction
Neural network는 object recognition부터 segmentation, detection 등 computer vision의 많은 부분을 차지했다. 하지만 이는 마지막 학습과정을 고려하지 않고 완전히 annotate되어있는 dataset을 이용하여 단 한번의 학습이 잘되도록 한다. Continual learning은 지속적으로 지식을 얻는 사람의 학습방법을 모방하는 것을 목표로 한다. 특히, continual learning은 semantic segmentation같은 dense prediction에서 빛을 발휘하는데, annotation이 노동집약적이고 자연스럽게 얻은 데이터들은 부정확하기 때문이다.
그러나, 기계는 새로운 것을 학습시키면 일찍 배웠던 task를 까먹는 continual learning에 치명적인 catastrophic forgetting이 발생한다. 더군다나, 각 픽셀 주변의 pattern이 다른 것을 분석함으로써 dense prediction task은 픽셀 단위의 classification이나 regression을 수행한다. continual segmentation의 forgetting을 어떻게 하면 더 유연하게 해결할 수 있을까? mask classification을 이용해 query기반의 class-agnostic binary mask와 class recognition을 forgetting 없이 continual learning을 수행한다. pixel 단위의 segmenter과 다르게 query 기반의 segmenter는 강한 objectness를 가지고 있어, continual segmentation에 좋은 2가지 이유를 가진다. 첫째, mask를 제안할 때 background가 완전히 배제되지 않기 때문에 unseen classd에 대해 objectness가 전이될 수 있다. (픽셀 단위일 때는 배경, 객체 두가지를 완전하게 분리하여 학습하고 쿼리 단위일 때는 두가지를 함께 고려하기 때문에 배경 내에 객체에 대한 정보를 모델이 학습할 수 있다.)

Figure 1(a)와 같이 pixel단위의 DeepLabc3와 query단위의 Mask2Former를 비교해보면, vehicle, plane, buse 같은 비슷한 class를 학습했을 때 새로운 class인 train을 query 단위의 Mask2Former가 mask를 더 잘 만드는 것을 볼 수 있다. 둘째, Figure 1(b)를 보면, old class에 대한 forgetting에 강인하다는 것을 확인할 수 있다.

이를 기반으로 Continual Learning with Mask-Then-Recognize Transformer인 CoMasTRe를 제안한다. 해당모델은 첫번째 단계에서 class-agnostic mask를 제안하고, 두번째 단계에서 recognition을 함으로써 objectness learnin과 class recognition으로 segmentation task를 나눈다. query 기반의 방법으로, 새로운 mask 제안이 간단하다. 또한, forgetting에 강한 continual objectness learning, continula classification을 통해 continual segmentation이 간단하다. 2가지 방법으로 CoMasTRe는 forgetting을 완화한다. 첫째, 긴 학습과정동안 old class objectness를 강화하기 위한 간단하지만 효과적인 objectness distillation 둘째, old class forgetting을 완화하기 위한 특정 classifier와 multi label class distillation.
PASCAL VOC 2012, ADE20K 데이터셋에 대해 CoMasTRe는 SOTA성능을 달성하였다.
Contribution은 다음과 같다.
- objectness를 활용하기 위해 Continual Learning with Mask-Then-Recognize Transformer decoder인 ComasTRe를 제안하며, objectness learning과 class recognition을 활용하여 continual segmentation을 쉽게 수행할 수 있다.
- forgetting 문제를 해결하기 위해 distillation을 활용하여 objectness를 강화하고, classifier와 multi label class distillation을 통해 class knowledge를 보존한다.
- ADE20K, PASCAL VOC 2가지 데이터셋에서 SOTA 성능을 달성하였다.
2. Related Work
Query-based Image Segmentation
Query기반의 segmenter는 mask classification을 통해 같은 framework로 semantic, instance, panoptic segmentation 문제를 해결해왔다. mask classification은 DETR과 같은 query기반의 detector로 시작되었으며, 이는 proposal(mask, 영역 같은), class label을 예측하도록 훈련한다. MaskFormer는 pixel단위의 CNN이나 transformer 를 활용한 최초의 mask classification이다. MaskFormer는 multiscale feature의 작은 단위 수행에 사용된다.
이 segmenter들은 mask와 class를 동시에 예측하지만 continual segmentation에는 적합하지 않다.
Continual segmentation
model을 finetuning함으로써 새로운 knowledge를 학습한다는 것은 기본적인 방식이지만 old knowldege를 catastrophic forgetting을 유발한다. Continual learning에서는 new knowldege를 잘 학습하고(plasticity), old knowldege를 보존하는 것(stability)의 균형을 잘 맞추는 것이 목표이다. background를 바꾸는 것은 forgetting 문제를 악화시킨다. pixel 단위의 기존 방법들은 distillation, pseudo-labeling을 통해 문제를 해결했다. CoMFormer는 최초로 mask classification 방법을 제안하였지만 objectness를 활용하지 못했다. CoMasTRe는 objectness를 활용하여 objectness learning과 class recognition을 나누어 forgetting문제에 강인하게 만들었다.
Continual Dynamic Networks
distillation방법과 더불어 parameter를 확장하는 dynamic network 또한 continual learning에서 핵심방식으로 자리잡고 있다. 최근에는 transformer기반의 dynamic nature을 활용한 방식들(task 특화 prompt, token확장)도 제안되고 있다. 본 논문에서는 task query와 class 특화 classifier를 통해 forgetting문제를 해결한다.
3. Method
3.1. Problem definition
image segmentation은 mask classification문제로 이어져왔다. class agnostic mask를 제안하고 동시에 class label을 예측하는 방식이다. dataset D에 image(x = C * H * W = 채널 * 높이 * 너비)와 target(y) 쌍을 포함한다. 각 target은 GT M (mask, class label)로 구성된다.


pixel기반의 classification과 달리 background class는 annotation에서 배제한다.
continual segmentation에서는 new class mask를 예측하고 old class를 잊지 않도록 학습한다. 본 논문에서는 T task로 학습과정을 나누고 각 step t = 1, 2, 3 ,... , T이다. 각 시간마다 unique한 class만 학습하게 된다. 즉, 각 t시간마다 겹치는 class는 없고 모든 t시간을 합쳤을 때 전체 class C가 되어야 한다. t에서는 새롭게 학습할 class만 annotation되어 있고, model은 모든 class를 학습해야 한다.
(t=1일 때 C1={cat, dog}에 대한 annotation만 포함, t=2일 때 C2={bird, fish}에 대한 annotation만 포함. 하지만 t=2일 때 C={cat, dog, bird, fish} 모두를 예측할 수 있어야 한다. 새로운 class에 대한 annotation만 주어지는 이유는 continual learning을 할 때 기존의 class를 반복학습하지 않는 경우가 많기 때문에 새로운 class에 대해서만 학습한다.)
3.2. CoMasTRe Architecture
저자들은 objectness와 continual classification연구들을 통해, mask classification을 활용한 continual segmentation문제를 해결하는 것이 유리하다고 주장한다. 따라서 Mask2Former를 활용하면서, objectness learning과 class recognition을 수행할 수 있는 2개의 새롭게 설계된 transformer decoder를 추가한 CoMasTRe를 제안한다.

전반적인 학습과정은 다음과 같다.
1. image가 입력되면, backbone에 들어가 pixel decoder를 통해 pixel 임베딩을 생성한다. pixel decoder의 4개 레이어에서 각각 pixel 임베딩을 추출한다.

2. 무작위로 초기화된 qeuery를 objectness learning의 입력으로 사용한다. positional embedding을 추출하기 위해 mask decoder f를 사용하여 class agnostic mask proposal M과 objectness score S를 예측한다. GT와 Bipartite Mathcing을 이용하여 학습한다.

3. positional embedding과 pixel embedding을 결합해 class decoder의 입력으로 사용하여 recognition을 수행한다. task interference를 줄이기 위해, task query와 task specifice classifier는 class knowledge에 특화되도록 학습한다.

4. 마지막으로 stage 1의 mask proposal과 objectness score과 stage 2의 class prediction을 결합하여 segmentation result를 얻게 된다.
3.2.1 Stage 1: Objectness Learning
Fig. 2(b)와 같이 objectness learning은 mask decoder에 의존한다. mask decoder는 Transformer layer의 L block으로 이루어져 있으며, N개의 positional queryh와 중간 pixel embedding을 입력으로 사용한다. 그 후 positional embedding을 통해 mask proposal과 objectness socre가 출력되고, stage 2 recognition을 진행한다. mask proposal M은 아래와 같이 계산된다.

각 positional embedding은 mask proposal을 생성한다. 동시에 positional embedding은 objectness head에 입력되며, binary classifier는 objectness score를 출력하며, 이는 mask proposal이 object을 포함하는지 포함하지 않는지를 나타낸다. 학습동안 bipartite matching을 수행하여 N개의 mask proposal과 M개의 GT간의 cost를 계산한다. N >> M이라고 가정하고 예측과 GT간의 1:1 매칭을 보장하기 위해 no object인 GT pad를 추가한다. Hungarian 알고리즘을 통해 matching을 수행한 후 N개 요소의 최적 순열을 획득해 σ로 지정한다.

Loss mask는 cross entropy loss와 Dice loss의 합을 나타낸다. matching결과를 기반으로 M개의 positional embedding이 match되며, stage 2 class decoder로 전달된다. matching된 것과 matching되지 않은 것을 활용하여 objectness score를 학습한다. 그렇지 않으면, 해당 모델은 모든 proposal에 대해 높은 objectness score를 가지게 된다. matching되지 않은 embedding은 distillation을 통해 forgetting을 완화시킨다.
3.2.2 Stage 2: Class Recognition
Stage 2는 mask proposal의 object를 인식하는 것이 목표이다. 이를 위해 pixel embedding과 매칭된 positional embedding을 입력으로 사용한다. 또한 task query를 함께 입력하여 task specific classifier를 수행한다. 이를 통해 task interference를 감소시켜 더욱 특화된 classifier를 만들어 continual learning을 잘 수행할 수 있게 한다.
continual learning은 T개의 task를 포함하기 때문에 현재 learning step을 t라고 가정한다. 각 positional embedding과 task query는 한쌍으로 취급하며 pixel embedding과 함께 class decoder에 입력되어 task embedding을 생성한다. t개의 task specific classifier는 task embedding에 적용된다. 각 proposal의 class probability를 획득하여, mask proposal과 positional embedding을 매칭한다. 모든 positional embedding을 입력함으로써 모든 proposal의 class probability를 얻을 수 있다. GT class label을 사용하여 classification loss와 stage를 학습한다.

calibration을 위해 cross entropy loss가 아닌 focal loss를 사용했으며, 이는 continual learning동안 distillation 과정을 더 잘 수행할 수 있도록 한다. 전반적인 segmentation loss는 objectness loss와 classification loss의 합이다. 2단계를 segmentation loss가 최소화되도록 학습한다. inference단계에서, objectness 임계값을 설정하여 low confident prediction은 거르고 높은 objectness embedding은 stage 2로 넘어갈 수 있도록 한다.
3.3. Learning without Forgetting with CoMasTRe
새로운 class를 학습하는 것은 catastrophic forgetting을 유발한다. 이를 막기 위해 CoMasTRe는 두단계 모두에서 distillation을 고려한다.

3.3.1 Objectness Distillation

Fig. 3(a)와 같이, bipartite matching을 통해 학습하고, matching된 positional embedding과 matching되지 않는 positional embedding을 얻게 된다. 하지만, matching되지 않는 embedding의 objectness score는 빠르게 줄어든다. GT가 old class의 mask를 포함하지 않고 그 결과 old class의 objectness score가 정확하지 않기 때문이다. 따라서, 이전단계의 objectness score를 유지하기 위한 objectness score distilllation을 제안한다. 또한 objectness를 더 향상시키기 위해 mask proposal, positional embedding에 대한 2가지의 distillation loss를 제안한다.

old class의 forgetting을 완화하기 위해 마지막 단계의 출력을 distill한다. distillaion과정에는 마지막 단계의 mask proposal, objectness score, positional embedding이 필요하다.
Distill objectness socres.
현재 matching되지 않은 objectness score와 마지막 단계의 objectness socre간의 Kullback-Leibler(KL) divergence를 최소화함으로써 object ness score가 줄어드는 것을 완화하기 위해 vanilla knowledge distillation loss를 사용한다.

Distill mask proposals.
mask proposal의 knowledge를 유지하면서 높은 objectness score를 맞추기 위해, objectness score로 mask distillation을 재계산한다. 낮은 objectness score를 distill함으로써, objectness score가 objectness를 잘못 인식했을 경우를 완화시킨다. 또한 일부 mask proposal은 objectness score가 낮은 unseen class에 대해 일반화 성능을 개선한다.

Distill positional embeddings.
mask distillation과 유사하게, matching되지 않는 positional embedding을 이전 단계의 positional embedding과 유사하게 하여 position distillation도 재계산한다. 현재의 단계와 이전 단계의 높은 cosine similarity를 가지게 하여 multi label class distillation의 positional 정보를 보존할 수 있도록 한다. (unmatched positional embedding에 집중한 이유는 새로운 knowledge를 학습함을써 forgetting이 일어나는 경우가 많다. 따라서 이를 방지하기 위해 unmatched에 집중한다.)

최종적으로 아래의 loss가 최소화되도록 objectness를 distill한다.

3.3.2. Class Distillation
class distillation을 두단계로 나눌 수 있다. (1) GT와 matching되지 않는 이전 단계의 높은 objectness embedding의 class 예측을 distilling (2) matching되는 embedding의 class knowledge를 distilling

Distill fron unmatched queries.
현재 단계와 이전단계의 class probability간의 KL divergence를 최소화하여 old class의 pseudo labeling을 비슷하게 하는 방법이다. 이전단계의 matching되지 않는 embedding이 주어지면, 임계값을 초과하는 n개의 높은 objectness socre를 가지는 embedding을 선택한다. 현재의 embedding 또한 선택한다. class decoding 후 이전 단계의 class logit과 현재 단계의 class logit을 획득한다. 두 단계의 KL divergence를 계산하여 knowledge distillation loss를 줄이고 old class의 robability를 강화한다.
Distill from matched queries.
새로운 sample에 대해 forgetting을 완화하며, 이전 단계의 mopdel 예측을 고려한다. knowldege를 추출하기 위해, 현재의 matching되는 embedding을 재사용하고, old class probability를 획득한다. 또한, 이전 단계에서 얻은 probability를 target distribution으로 사용한다. 학습동안 class distillation loss를 최소화한다.
loss를 최소화함으로써 class knowldege를 보존한다. 또한 auxiliary loss를 사용하여 유사한 class에 대해 더욱 dicsriminative classification을 수행한다.
4. Experiments
4.1. Setup
Datasets.
PASCAL VOC 2012(20개의 object, background class), ADE20K(150개의 class)를 사용
Continual Segmentation Protocols.
Continual segmentation에서는 sequential, disjoint, overlapped 방식이 있다. sequential은 가장 쉬운 방법으로 각 단계에서 모든 class의 pixel별 GT를 포함한다. disjoint방식은 이전의 class를 background로 사용하고 미래의 class는 학습에서 배제한다. overlapped는 가장 어려운 방식으로 dataset은 현재 class의 GT만 포함하고 past, future 모두 background 취급한다. 저자들은 overlapped 방식을 채택한다.
Metric.
mena intersection over union(mIoU)를 사용한다. base class C1과 incremented class C2:T를 측정한다. base class는 stability, incremented는 plasticity를 나타낸다. mIoU 평균은 전반적인 continual learning을 평가하고, joint는 모든 class를 동시에 학습했을 때의 성능을 나타낸다.
Implementation details.
benchmark는 ImagerNet으로 pretrain된 ResNet101을 사용한다. pixel decoder는 Mask2Fromer, layer 9개로 구성된 Transformer decoder block을 사용한다. positional query는 20~100개로 구성된다.
4.2. Quantitative Evaluation


4.3. Ablation Studies
Joint training results.

Objectness transfer ability analysis.

Effectiveness of objectness distillation.

Effectiveness of class distillation.

Effectiveness of stage 2 other components.

5. Conclusion
objectness learning과 class recognition으로 나눈 continual segmentation framework인 CoMasTre를 제안한다. objectness를 활용하여 2단계의 query기반의 segmenter를 제안하고, objectness와 classification knowledge를 각각 distill하여 forgetting을 완화한다. SOTA 달성을 하였다. CoMasTRe는 continual semantic segmentation에 활용될 수 있으며, panoptic, instance segmentation에도 활용될 수 있다.