An Overview of Contrastive Learning

SimCLR, MoCo, SwAV, BYOL, CLIP, DeepCluster, PIRL, Barlow Twins. Understanding Contrastive Learning and its mainstream works.

Rice Yang
12 min read · Dec 2, 2022
Source: Why You Need to Stop Comparing Yourself With Others — Lolly Daskal | Leadership

Contrastive Learning (CL) has become a hot topic in both academia and industry due to its wide range of potential applications. With Contrastive Learning, we can produce pre-trained models that learn effective representations with strong generalization, which accelerates downstream development. Contrastive Learning is already applied in many practical products, e.g. video platforms, social networks, e-commerce, etc. It is also a likely interview question for machine learning engineers in related fields.

This article categorizes the components of Contrastive Learning from the top down and summarizes the mainstream methods into a single table to compare their differences.

Table of Contents

  • The Concept of Contrastive Learning
  • The Five Components in Contrastive Learning
  • Parallel Augmentation: End-to-End, Momentum Encoder
  • Architecture: End-to-End, Memory Bank, MoCo, Clustering
  • Loss Function: NCE, InfoNCE
  • Data-Modal: CLIP
  • A Summary of Mainstream Methods
  • Contrastive Learning vs. Self-Supervised Learning

The Concept of Contrastive Learning

Like VAE and GAN, Contrastive Learning can be applied as a kind of Unsupervised Learning. However, when we apply Contrastive Learning to unlabeled data, we usually categorize it as Self-Supervised Learning rather than Unsupervised Learning.

If the concept of VAE is learning to reproduce itself, then the concept of Contrastive Learning is learning to make itself similar to itself.

Starting from Facial Recognition

The facial recognition model FaceNet, published in 2015, shares a similar idea with Contrastive Learning. Before FaceNet, supervised learning for facial recognition only pushed the model to keep high similarity within the same person. FaceNet proposed the triplet loss, which also forces the model to keep the similarity between different persons low enough.

In traditional facial recognition, the model only aims to keep high similarity among all pictures of Obama by making their feature points close enough in feature space (right). FaceNet not only keeps that design but also pushes Trump’s feature point away from Obama’s to keep their similarity low. Source: Building a dog search engine with FaceNet

Since Contrastive Learning was not popular at that time, the paper does not use any Contrastive Learning terminology. But today we can see that it focuses on maximizing the similarity within a class and minimizing the similarity between classes. The key difference is that in Contrastive Learning, the in-class data is generated by augmentation instead of being collected.
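For illustration, here is a minimal PyTorch sketch of the triplet-loss idea: embeddings of the same identity are pulled together, and embeddings of different identities are pushed apart by at least a margin. The margin value and the 128-dimensional embeddings are illustrative assumptions, not FaceNet’s exact settings.

```python
import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=0.2):
    # Distances in the embedding space: pull the positive close, push the negative away.
    d_pos = F.pairwise_distance(anchor, positive)   # same identity
    d_neg = F.pairwise_distance(anchor, negative)   # different identity
    # Penalize only when the negative is not at least `margin` farther away than the positive.
    return F.relu(d_pos - d_neg + margin).mean()

# Toy usage with random 128-dimensional embeddings.
anchor, positive, negative = torch.randn(3, 8, 128).unbind(0)
print(triplet_loss(anchor, positive, negative))
```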

Data Augmentation

Data augmentation can generate multiple samples from a single one without breaking its semantic content. For example, the following picture shows several data augmentation methods and their corresponding results. They all still show a dog.

Some methods of data augmentation. Source: SimCLR Paper

For implementation, we usually predefine a fixed probability value that triggers a certain data augmentation method. In the GitHub code from SimCLR, an augmentation is triggered if a random number is less than the probability p.
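Here is a minimal sketch of this trigger-by-probability pattern with torchvision; the specific transforms and p values below are illustrative, not SimCLR’s exact recipe.

```python
from torchvision import transforms

# Each random transform fires only when an internal random number falls below its p.
augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])
```

Calling augment(img) twice on the same image yields two different “noise” views, which is exactly how the positive pairs in the following sections are produced.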

Data augmentation helps us increase the data variety, get several times more data for free, and boost the final accuracy of the learned model. There are also some approaches for automatic data augmentation, e.g. AutoAugment, RandAugment, etc. PyTorch has already implemented some automatic methods for us; see “Automatic Augmentation Transforms”.

From Data Augmentation to Contrastive Learning

Here is the idea: if we have only one unlabeled sample in each class, can we utilize data augmentation to enlarge the data size and transform the problem from unsupervised learning into supervised learning?

The answer is YES. Suppose we have N samples and augment each of them K times to get N×K samples with N classes and K samples per class. That’s the basic idea of Contrastive Learning.

Any two original samples will always be assigned to different classes, so this may produce some erroneous labels. For example, two photos of Obama will be treated as different identities. This is an inevitable problem: if you manually fix the labels with additional prior knowledge, it is no longer Contrastive Learning but a Supervised Learning setting. To cope with it, Contrastive Learning usually needs a massive amount of data to reduce the negative effect of the erroneous labels.

You may ask: what is the point of training a classifier that treats each sample as its own class? If you train on big data, the learned model will have very powerful generalization in its feature representation and can be applied to downstream tasks in zero-shot or few-shot settings. Take facial recognition as an example: even if Biden’s photo did not exist in the training set, the model can still recognize Biden by comparing the similarity of the extracted facial features.

So, what’s the Contrastive Learning?

Contrastive Learning aims to learn a general data representation with powerful generalization. The learned representations have high similarity for similar data and low similarity for different data.

The Five Components in Contrastive Learning

Based on my study, I think there are five principal components of Contrastive Learning. We can derive different methods by combining the possible choices for each component.

  1. Data Augmentation. This is the cornerstone of Contrastive Learning. Proposing a new simple method may contribute a lot to the entire machine learning field, e.g. MixUp, CutMix.
  2. Parallel Augmentation. How to extract feature representations in parallel for the different augmented data. There are mainly two designs: end-to-end and momentum encoder.
  3. Architecture. How to manipulate the positive and negative sample pairs for calculating the loss. There are roughly four categories: end-to-end, memory bank, MoCo, and clustering.
  4. Loss Function. The different formulas may affect the final result. The most important one may be InfoNCE.
  5. Data-Modal. Some methods are designed for certain data modalities or multimodal data.

Parallel Augmentation

In Contrastive Learning, every sample has an original anchor version and augmented noise versions. Parallel Augmentation is closely related to the architectures described later and focuses on how to extract features from the noise and anchor data.

End-to-End Encoder

An end-to-end encoder extracts features from all the data within a batch using the same network structure and weights. In other words, its design is the same as a Siamese network.

The most straightforward work in this category is SimCLR. SimCLR generates 2 noise samples from 1 anchor sample and pairs all the views with each other to calculate the similarity for every data pair. With a batch size of N, we get 2N×(2N−1) data pairs, including 2N positive pairs and 2N×(2N−2) negative pairs.

The SimCLR process animation. Source: google-research/simclr

All images within a batch share the same CNN weights. The gradients back-propagate and update the CNN weights through all positive and negative pairs.
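As an illustration, here is a minimal PyTorch sketch of the batch-wise contrastive loss SimCLR computes over these pairs (its NT-Xent loss), assuming the 2N projected features are stacked so that rows i and i+N come from the same anchor; the temperature value is illustrative.

```python
import torch
import torch.nn.functional as F

def nt_xent(z, temperature=0.5):
    """z: (2N, d) projected features, where z[i] and z[i + N] are two views of one anchor."""
    n = z.shape[0] // 2
    z = F.normalize(z, dim=1)                           # cosine similarity via normalized dot products
    sim = z @ z.t() / temperature                       # (2N, 2N) pairwise similarities
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float('-inf'))          # a view is never compared with itself
    pos = torch.arange(2 * n, device=z.device).roll(n)  # index of each view's positive partner
    return F.cross_entropy(sim, pos)                    # softmax over 1 positive + 2N-2 negatives
```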

Momentum Encoder

In physics, the noun momentum describes a measure of inertia. In machine learning, momentum sometimes stands for the concept of a moving average.

The paper BYOL may help us better understand the idea of momentum encoder. The following picture shows the design of BYOL. Note that BYOL calculates the loss with only positive pairs.

BYOL illustration. The sg stands for stop-gradient, which means the gradient stops there and leaves the parameters ξ unchanged. Source: BYOL Paper

The authors use the symbols f, g, q to denote different neural networks and the subscripts θ, ξ to denote their weight parameters. The upper path is the online network with parameters θ, and the lower path is the target network with parameters ξ. When θ = ξ, it is equivalent to an end-to-end design. When ξ is set to the moving average of θ, f( · ; ξ) is a momentum encoder.

So a momentum encoder is a moving-averaged version of an end-to-end learned encoder. The idea is similar to batch normalization, which uses a moving average of the batch mean and variance to estimate the statistics used at inference time.
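Here is a minimal sketch of the moving-average update, assuming online and target are two copies of the same network; the momentum value is illustrative (BYOL and MoCo use values close to 0.99–0.999).

```python
import copy
import torch

@torch.no_grad()
def momentum_update(online, target, m=0.99):
    # target <- m * target + (1 - m) * online, applied parameter by parameter.
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.data.mul_(m).add_(p_o.data, alpha=1 - m)

# Typical setup: the target starts as a copy of the online network and is never
# updated by gradients, only by this moving average after each training step.
online = torch.nn.Linear(128, 64)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)
```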

The benefit of the momentum encoder is to avoid over-fitting, since it is not easy to eliminate the difference between positive pairs by gradient descent alone: the difference comes from both the data augmentation and the momentum encoder, and neither is reachable by gradient descent. The results of BYOL also implicitly show that this design is robust enough to ignore the loss from negative pairs.

Architecture

The noun architecture here means not only the basic network architecture, like ResNet and VGG, but also the strategy for sampling the positive and negative pairs.

End-to-End

It directly enumerates the possible positive and negative pairs within a minibatch. Besides SimCLR, the simplest method may be Barlow Twins, which computes a cross-correlation matrix directly from the two streams of features and pushes it toward the identity matrix I.

The training of Barlow Twins. Source: Barlow Twins Paper
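Here is a minimal sketch of this cross-correlation objective, assuming z1 and z2 are the (N, d) feature matrices of the two augmented streams; the off-diagonal weight is illustrative.

```python
import torch

def barlow_twins_loss(z1, z2, lam=5e-3):
    n, d = z1.shape
    # Standardize each feature dimension, then build the d x d cross-correlation matrix.
    z1 = (z1 - z1.mean(0)) / z1.std(0)
    z2 = (z2 - z2.mean(0)) / z2.std(0)
    c = z1.t() @ z2 / n
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()                # push the diagonal toward 1
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()   # push the off-diagonal toward 0
    return on_diag + lam * off_diag
```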

There is a common problem among all end-to-end architectures like SimCLR and Barlow Twins: there are too few negative pairs. The numbers of negative pairs and negative classes grow only linearly with the batch size in end-to-end architectures. So we need to train SimCLR with a very large batch size, or the performance will be much worse.

The ideal formulation would calculate the loss with all possible negative pairs to cover all negative classes, which means batch size = number of classes = dataset size, and that is impossible. When we have millions of samples, we cannot cache millions of features in memory and compute millions × millions of similarities due to hardware limitations.

Memory Bank

The memory bank is a simple design to fix the above problem. Since we always calculate the similarity in the feature domain instead of the data domain, we can cache all the data features in a lookup table. This design reduces computation by replacing the feature extraction procedure with a table lookup. The cache table is called the memory bank.

PIRL is a simple and successful work from recent years. It caches all encoded representations in the memory bank and calculates the loss for negative samples by looking them up in the memory bank. Note that the representations in the memory bank are moving-averaged representations, not exact copies.

Source: PIRL Paper
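Here is a minimal sketch of a memory bank with moving-average updates as described above; the bank size, feature dimension, and momentum are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

class MemoryBank:
    """One moving-averaged feature slot per training sample, addressed by sample index."""
    def __init__(self, num_samples, dim, momentum=0.5):
        self.bank = F.normalize(torch.randn(num_samples, dim), dim=1)
        self.momentum = momentum

    @torch.no_grad()
    def update(self, indices, features):
        # Blend the newly extracted features into the cached ones instead of overwriting them.
        old = self.bank[indices]
        new = self.momentum * old + (1 - self.momentum) * features
        self.bank[indices] = F.normalize(new, dim=1)

    def sample_negatives(self, k):
        # Draw k cached representations to serve as negatives without re-encoding any image.
        idx = torch.randint(0, self.bank.shape[0], (k,))
        return self.bank[idx]
```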

MoCo: Momentum Contrast

MoCo is an advanced design that combines the memory bank and the momentum encoder. MoCo aims to fix the problem that the memory and computation complexities of the memory bank architecture are still large. It uses a size-limited FIFO queue to cache past representations and uses them as negative samples.

The FIFO queue reduces the memory and computation complexity from the scale of the dataset size to the scale of the queue size. On the other hand, this method trades away a little accuracy: according to the experiments, the reported accuracies of MoCo are slightly worse than those of PIRL (memory bank).

The basic illustration of MoCo. Source: MoCo Paper
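Here is a minimal sketch of the fixed-size FIFO queue of keys and the resulting logits; the queue length, dimensions, and temperature are illustrative, and the key encoder would be updated with the same moving-average rule sketched earlier.

```python
import torch
import torch.nn.functional as F

class KeyQueue:
    """Fixed-size FIFO queue of momentum-encoder outputs used as negatives."""
    def __init__(self, size=4096, dim=128):
        self.queue = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, keys):
        # Overwrite the oldest entries with the newest batch of keys.
        n = keys.shape[0]
        idx = (self.ptr + torch.arange(n)) % self.queue.shape[0]
        self.queue[idx] = keys
        self.ptr = (self.ptr + n) % self.queue.shape[0]

def moco_logits(q, k, queue, temperature=0.07):
    """q, k: (N, d) L2-normalized query and key features of the same anchors."""
    l_pos = (q * k).sum(dim=1, keepdim=True)                 # (N, 1) one positive per query
    l_neg = q @ queue.queue.t()                              # (N, K) negatives from the queue
    return torch.cat([l_pos, l_neg], dim=1) / temperature    # the target class is always 0
```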

Clustering

Clustering has been a leading approach in Unsupervised Learning for decades. Recently it has been applied to Contrastive Learning to solve the problem of having too many negative classes.

Source: DeepCluster Paper

DeepCluster directly shows how clustering can be applied in a deep learning architecture. As the above illustration shows, it adds a k-means clustering task in the feature space and assigns a pseudo-label to every sample. The loss can then be calculated under the supervision of these pseudo-labels.
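Here is a minimal sketch of this pseudo-labeling step, assuming the features of the whole dataset have already been extracted; it uses scikit-learn’s KMeans for brevity, and the numbers of samples, clusters, and dimensions are illustrative.

```python
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans

# features: (num_samples, dim) tensor extracted by the current encoder.
features = torch.randn(1000, 128)

# Step 1: cluster the features and treat the cluster assignments as pseudo-labels.
kmeans = KMeans(n_clusters=100, n_init=10)
pseudo_labels = torch.as_tensor(kmeans.fit_predict(features.numpy())).long()

# Step 2: train a classifier head on top of the encoder with these pseudo-labels,
# exactly as in supervised learning (the logits would come from encoder + linear head).
logits = torch.randn(1000, 100, requires_grad=True)   # stand-in for model outputs
loss = F.cross_entropy(logits, pseudo_labels)
```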

Compare SwAV with end-to-end architecture. Source: SwAV Paper

SwAV is the most representative method among the clustering-based architectures. The above figure illustrates that SwAV uses the prototypes C to cache the center features of the clusters. It also introduces the codes Q, which represent the extracted feature z by its similarities to each cluster, and uses them to calculate the loss. By the way, just like BYOL, SwAV is a design that only considers the positive pairs in its loss function.

Why does SwAV introduce the codes Q instead of directly using the inner product of z and C? The reason is that SwAV only considers positive pairs, which may drive a gradient-descent-based optimizer to update all prototypes toward the same value, i.e., a mode collapse. To avoid mode collapse, it computes Q separately with the Sinkhorn-Knopp algorithm instead of a gradient-descent-based optimizer.

In short, SwAV keeps maximum similarity within clusters by gradient descent and minimum similarity between clusters by the Sinkhorn-Knopp algorithm.
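Here is a minimal sketch of the Sinkhorn-Knopp iteration that turns the feature–prototype similarities into the codes Q by alternately balancing rows and columns; the epsilon and iteration count are illustrative.

```python
import torch

@torch.no_grad()
def sinkhorn(scores, eps=0.05, n_iters=3):
    """scores: (N, K) similarities between the features z and the prototypes C."""
    q = torch.exp(scores / eps)          # soft assignments before balancing
    q /= q.sum()
    n, k = q.shape
    for _ in range(n_iters):
        q /= q.sum(dim=0, keepdim=True)  # balance columns: every prototype gets equal mass
        q /= k
        q /= q.sum(dim=1, keepdim=True)  # balance rows: every sample distributes equal mass
        q /= n
    return q * n                         # each row is now a soft code summing to 1
```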

Moreover, SwAV also introduces multi-crop, which adds more low-resolution noise views to increase the data size and boost performance.

Loss Functions

Most losses in machine learning could be introduced into Contrastive Learning. Here we pay more attention to two loss functions designed for Contrastive Learning: NCE and InfoNCE.

NCE: Noise-Contrastive Estimation

NCE is an older method, proposed in 2010, but it is still effective in Contrastive Learning. To calculate the loss for a data sample x via NCE, a positive sample x⁺ and a set of negative samples X⁻ are selected; the loss makes the similarity between x and x⁺ large and keeps the similarities between x and X⁻ small. The NCE loss can be presented with the following formula:
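One common way to write it, with s(·, ·) a similarity function and τ a temperature (an assumption consistent with the description below, where h is a softmax-style probability estimate and the loss is a binary cross entropy):

$$h(x, x') = \frac{\exp\big(s(x, x')/\tau\big)}{\exp\big(s(x, x')/\tau\big) + \sum_{x^- \in X^-} \exp\big(s(x, x^-)/\tau\big)}$$

$$\mathcal{L}_{\mathrm{NCE}} = -\log h(x, x^+) - \sum_{x^- \in X^-} \log\big(1 - h(x, x^-)\big)$$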

The h in the first row stands for a probability estimation function based on the softmax formula. The loss in the second row is a standard form of binary cross-entropy (BCE) loss. In short, NCE is a BCE loss whose probabilities are estimated from the similarities to one positive sample and a set of negative samples.

Finally, the similarity function is usually designed as the cosine similarity of normalized vectors, divided by a temperature parameter.

InfoNCE

InfoNCE was introduced in the CPC paper in 2018. It is a cross-entropy version of NCE: a positive sample x⁺ and a set X containing all samples are selected, and the loss is calculated with cross-entropy:
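One common way to write it, using the same similarity s(·, ·) and temperature τ as above:

$$\mathcal{L}_{\mathrm{InfoNCE}} = -\log \frac{\exp\big(s(x, x^+)/\tau\big)}{\sum_{x' \in X} \exp\big(s(x, x')/\tau\big)}$$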

According to the paper Understanding the Behaviour of Contrastive Loss, InfoNCE is a hardness-aware loss that allows the learning to focus more on hard samples. This is a nice property, since many machine learning studies show that performance can be improved with hard-sample mining.

Data-Modal

So far, all the Contrastive Learning methods in this article are based on image input and image augmentation. But if our data is multimodal, we can apply Contrastive Learning by estimating the similarities between different modalities of the same data sample, as CLIP does. By the way, CLIP uses a cross-entropy loss on cross-modal features, so it is actually an InfoNCE loss.

Source: CLIP Paper
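Here is a minimal sketch of this symmetric cross-modal InfoNCE, assuming the image and text features are already extracted and projected to the same dimension; the fixed temperature is an illustrative simplification (CLIP learns its temperature).

```python
import torch
import torch.nn.functional as F

def clip_style_loss(image_feat, text_feat, temperature=0.07):
    """image_feat, text_feat: (N, d) features of N paired image-text samples."""
    image_feat = F.normalize(image_feat, dim=1)
    text_feat = F.normalize(text_feat, dim=1)
    logits = image_feat @ text_feat.t() / temperature   # (N, N) cross-modal similarities
    targets = torch.arange(logits.shape[0])             # the i-th image matches the i-th text
    # Symmetric cross-entropy: classify the right text for each image and vice versa.
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
```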

A Summary of Mainstream Methods

The following table summarizes all the methods mentioned in this article for a general overview. Limited by my knowledge, many works are not listed in the table.

Summary of Contrastive Learning methods. The “+” and “−” in the loss-function column indicate whether the method uses positive or negative pairs.

Moreover, the parameters-versus-accuracy relationship of the mainstream methods is plotted in the following figure, taken from the paper A Survey on Contrastive Self-Supervised Learning. The figure shows that the accuracy of SwAV is better than the others and very close to Supervised Learning.

The comparison of CL methods on ImageNet Top-1 accuracy. Source: A Survey on Contrastive Self-Supervised Learning

Note that you can hardly reproduce the accuracy of SimCLR unless you have very powerful machines, because it uses a very large batch size (8192). A lower batch size will hurt the accuracy because the number of negative pairs is insufficient.

Contrastive Learning vs. Self-Supervised Learning

When we apply Contrastive Learning to unlabeled data, it is a branch of Self-Supervised Learning, and also the most powerful one. It can help us obtain a powerful data representation from big data. The representations allow us to implement further tasks by similarity matching, e.g. facial recognition and person re-identification. The learned encoder can also be applied to transfer learning such as few-shot or zero-shot learning.

On the other hand, Self-Supervised Learning has other branches besides Contrastive Learning. It’s totally wrong if you think “Self-Supervised Learning = Contrastive Learning”.
