SimCSE - Simple Constrastive Learning of Sentence Embeddings

Key Value
Paper https://arxiv.org/abs/2104.08821
publication Date 2021

Problem

  • 기존 방식
    • 기존 문방 embedding 방식은 문장간 의미적 유사도를 잘 반영하지 못했음
    • 따라서 기존 방식은 data augmentation또는 복잡한 학습 전략 없이는 성능 한계가 있었음
    • supervised 방식은 많은 labeled data가 필요했음
  • SimCSE
    • SimCSE는 Dropout만 간단하게 사용해, data augmentation을 수행함
    • constrastive learning을 통해서 문제를 해결

Summary of the Model

  • Unsupervised and Supervised (두가지 방식 모두 지원)
  • Unsupervised SimCSE
    • 단순히 input sentence 자신을 dropout한 이후에 자기 자신을 예측 하는 방식
      • 즉 동일한 문장을 pretrained model에 “두번” 넣은후 -> dropout -> 두개의 문장 embedding을 얻음
      • 이렇게 얻은 두개의 embeddings은 positive pairs 로 간주됨
      • 이후 동일한 mini batch내에서 다른 문장들은 “negatives” 로 간주함
      • representation collapse에 강함
        • 서로 다른 다양한 문장들이 동일한 embedding으로 수렴되는 현상
        • Dimensional Collapse: embedding vectors들이 전체 embedding space가 아니라 특정한 저차원 공간에 “집중”되는 현상
        • Representation Collapse: Vector Quantization 에서 주로 발생하는 현상으로, 특정한 클러스터 대표값으로 묶이면서, 다양한 패턴을 잃어버리는 현상
        • Neural Collapse: 분류 문제에서, embeddings들이 같은 클래스의 대표값 (평균값)으로 수렴하는 현상
  • Supervised SimCSE
    • NLI 데이터셋을 사용
      • 기존 3-way classification (entailment, neutral, contradiction) 사용 대신에, entrailment pairs 를 positive pairs로 사용하였음
      • 추가적으로 contradiction paris를 hard negative 로 사용했을때 성능이 더 향상되었음

Detailed Explanation

Unsupervised SimCSE

Unsupervised SimCSE는 label 없이도 문장 표현 학습이 가능한 방식으로, 아주 간단하면서도 효과적인 전략을 사용합니다.
방식은 다음과 같습니다

동일한 문장을 dropout을 두번 적용하여 embedidngs 을 생성 -> 이 둘을 positive pair로 간주하여 constrastive learning을 수행

1. 입력 문장 s
          ↓
2. 같은 sentence → encoder 통과 (dropout 적용)
          ↓
 ┌──────────────┐        ┌──────────────┐
 │  z1 = BERT(s)│        │  z2 = BERT(s)│
 │ (1st dropout)│        │ (2nd dropout)│
 └──────────────┘        └──────────────┘
          ↓                      ↓
       Positive Pair: (z1, z2)  ←──────┐
                                       │
                                       │
               Same sentence with different noise
                                       │
          ↓                            ↓
3. 다른 문장들의 embedding은 모두 negative로 간주
          ↓
4. InfoNCE loss로 contrastive 학습 수행

이렇게 하는 이유는 예를 들어 image라면 crop, flip, rotation, distortion등으로 augmentation을 수행할 수 있지만,
문장에서는 dropout이 가장 간단하면서도 효과적인 방법이기 때문입니다.

  • 동일한 문장을 두 번 인코딩하면 dropout 때문에 서로 조금씩 다른 embedding이 나옴
  • 이 두 벡터가 positive pair가 되고,
  • 같은 배치에 있는 다른 문장 embedding들은 전부 negative pair로 간주
  • 이렇게 batch 내부에서만 positive/negative를 구성하여 학습

Constrastive Learning

\[L = -log \frac{e^{sim(h_i, h^+_i) / \tau }}{ \sum^N_{j=1} e^{sim(h_i, h_j) / \tau}}\]
  • \(h_i, h^+_i\) : positive pair (dropout으로 생성된 두개의 문장 embedding)
  • \(sim(...)\) : cosine similarity \(\frac{h_i \cdot h^+_i}{\|h_i\| \cdot \|h^+_i\|}\)
  • \(\tau\) : temperature hyperparameter 보통 0~1 사이의 값이며 보통 0.05로 설정
    • 낮은값 (0.05): 빠른 수렴이지만, 오버피팅 위험 가능성
    • 높은값 (0.1): 느린 수렴이지만, 더 일반화된 표현을 학습할 수 있음

cosine similarity 계산하는 부분

  • F.normalize(z1, dim=-1) 는 벡터를 L2 norm으로 정규화 합니다. \(\frac{z_i}{ \| z_i \|_2}\)
    • z / torch.sqrt(z**2).sum()) 이것과 동일한 의미입니다.
  • L2 norm으로 정규화된 벡터는 서로 dot product (내적)을 계산하면, cosine similarity가 됩니다.

label 생성하는 부분

  • z1: 문장 A, B, C, D → [z1[0], z1[1], z1[2], z1[3]]
  • z2: 위 문장 A, B, C, D 를 dropout해서 다시 만든 버전 → [z2[0], z2[1], z2[2], z2[3]]

각 anchor embedding (row)에 대해
이 anchor의 정답 positive가 z의 몇 번째 위치에 있는지를 알려주는 label 배열을 만들기

쉽게 말하면, 서로의 index를 맞추는 작업.

def simcse_loss(self, z1, z2):
    """
    Compute SimCSE loss using InfoNCE
    Args:
        z1: [batch_size, hidden_size] - first representations
        z2: [batch_size, hidden_size] - second representations (dropout augmented)
    Returns:
        loss: contrastive loss
    """
    # Normalize embeddings
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)

    # Concatenate z1 and z2
    z = torch.cat([z1, z2], dim=0)  # [2*batch_size, hidden_size]
    
    # Compute similarity matrix (matmul is dot product)
    sim_matrix = torch.matmul(z, z.t()) / self.temp  # [2*batch_size, 2*batch_size]
    
    batch_size = z1.size(0)
    
    # Create labels for positive pairs
    # For i-th sample in z1, its positive is (i+batch_size)-th sample in z2
    # For i-th sample in z2, its positive is (i-batch_size)-th sample in z1
    # labels = [4, 5, 6, 7, 0, 1, 2, 3] for batch_size=4
    labels = torch.arange(2 * batch_size, device=z.device)
    labels[:batch_size] += batch_size  # z1's positives are in z2
    labels[batch_size:] -= batch_size  # z2's positives are in z1
    
    # Mask out diagonal (self-similarity)
    # We don't want to compare a sample with itself
    # tensor([[  -inf, 0.0000, 0.9939, 0.1104],
    #    [0.0000,   -inf, 0.1104, 0.9939],
    #    [0.9939, 0.1104,   -inf, 0.2195],
    #    [0.1104, 0.9939, 0.2195,   -inf]])
    mask = torch.eye(2 * batch_size, device=z.device).bool()
    sim_matrix = sim_matrix.masked_fill(mask, -float('inf'))
    
    # Compute cross-entropy loss
    loss = F.cross_entropy(sim_matrix, labels)
    
    return loss