A friendly intro to semi-supervised learning


How to leverage unlabeled data alongside labeled data in your deep learning pipeline to improve the performance of your model

Lukas Huber

Photo by Bernd Dittrich on Unsplash

In most academic or hobby deep learning challenges, you are likely to face well-known, already labeled datasets that can be fed directly into your model. To make things easier, researchers often rely on standard datasets such as CIFAR-10, SVHN, ImageNet, and many others to evaluate their architectures. This has the benefit that the expensive data collection and labeling have already been done by someone else, so you do not have to worry about it.

In practical use cases, however, where you need to collect your own data, things might become more cumbersome. The data collection process is often one of the most time-consuming and critical parts of a deep learning project. You have probably heard the phrase

“Garbage in, garbage out.”

before, which is especially true for neural networks. Mislabeled data points or biases in the training data can drastically reduce the predictive power of your model and render the whole pipeline useless. As a result, it is often hard to obtain enough high-quality samples to train a neural network within reasonable cost and time.

If one takes a closer look at the process of obtaining the training data, one will quickly realize that not all parts scale equally well. In many cases, it is easy to record thousands and thousands of samples; it is the subsequent labeling process that makes obtaining many samples so costly. Consequently, in practice, only a fraction of the available data is actually used to train neural networks, effectively throwing away all unlabeled samples. Such training is often referred to as supervised learning. In contrast, using only data points without any labels (e.g. for clustering) is called unsupervised learning.

With models getting bigger and bigger and requiring vast amounts of data to even converge, the question inevitably arises whether these unlabeled samples can also be leveraged during training.

Such a training schedule is called semi-supervised learning as it combines the core concepts of supervised and unsupervised learning.

This article will provide a friendly introduction to semi-supervised learning and explain its core concepts. Let’s get started!

Semi-supervised learning in action. The bold line shows the decision boundary obtained by supervised learning. The dotted line shows the boundary for the semi-supervised case. The dots are the unlabeled data points and the triangles/plus signs are the labeled ones. Figure taken from van Engelen et al. (2018)

The figure above shows the supervised and semi-supervised schemes in action. The circles stand for the unlabeled data points of both classes, while the triangles and plus signs correspond to the labeled samples. In order for the unlabeled samples to be helpful, we have to assume that they still contain useful information for us. More mathematically:

The underlying marginal distribution p(x) should provide useful information about the posterior p(y|x).

In order for semi-supervised training to work, we have to rely on three main assumptions:

Smoothness assumption

The smoothness assumption states that if two samples x1 and x2 are close in the input space, they should share the same label. For instance, consider a dataset describing cars by two features: weight and fuel consumption. Samples that have small values for both features are likely to represent compact cars, while ones with high values tend to correspond to SUVs. This assumption comes in handy when we also consider unlabeled data, since we expect them to share the label of their closest labeled neighbor.
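To make this concrete, here is a tiny NumPy sketch of that idea, with made-up numbers for the car example: each unlabeled sample simply inherits the label of its closest labeled neighbor.

```python
import numpy as np

# Toy illustration of the smoothness assumption (all numbers are invented):
# features are (weight in tonnes, fuel consumption in l/100 km).
labeled_x = np.array([[1.1, 5.5],    # a compact car
                      [2.4, 11.0]])  # an SUV
labeled_y = np.array(["compact", "SUV"])

unlabeled_x = np.array([[1.2, 6.0],
                        [2.2, 10.5]])

# Pairwise distances between every unlabeled and every labeled sample.
distances = np.linalg.norm(unlabeled_x[:, None, :] - labeled_x[None, :, :], axis=2)

# Each unlabeled sample inherits the label of its closest labeled neighbor.
inferred_y = labeled_y[distances.argmin(axis=1)]
print(inferred_y)  # ['compact' 'SUV']
```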

Low-density assumption

From the smoothness assumption, we can directly derive another premise: the decision boundary between the classes should lie in a low-density region of the input space. This means that it should lie in an area with few labeled and unlabeled samples. If it lay in a high-density region, the smoothness assumption would be violated, since samples that are close in the input space would no longer share the same label.

Manifold assumption

Data for machine learning tasks is often high-dimensional. However, not all features show the same level of variance, which makes some of them less useful for the model. As a result, the high-dimensional data often lies on a much lower-dimensional manifold (a lower-dimensional structure embedded in the input space). This structure can be exploited to infer the classes of unlabeled samples.
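As a rough illustration, the sketch below builds synthetic 50-dimensional data that actually varies only along two latent directions (the dimensions and noise level are arbitrary choices for this example) and checks how much of the variance a PCA attributes to them.

```python
import numpy as np
from sklearn.decomposition import PCA

# Hypothetical example: 50-dimensional data that really lives on a 2-d manifold.
rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 2))   # true 2-d coordinates of each sample
mixing = rng.normal(size=(2, 50))    # embed them into 50 dimensions
data = latent @ mixing + 0.01 * rng.normal(size=(500, 50))  # plus a little noise

pca = PCA(n_components=10).fit(data)
print(pca.explained_variance_ratio_.round(3))
# Nearly all of the variance sits in the first two components, so neighborhoods
# and class memberships are better judged on this 2-d structure than in the raw
# 50-dimensional space.
```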

These three assumptions form the foundation of almost all semi-supervised learning algorithms.

One notable thing about recent semi-supervised learning algorithms is that they are all based on one of two paradigms, and sometimes even on both.

The first paradigm is called pseudo-labeling, which uses the network itself to generate ground-truth labels for the unlabeled data. To do this, the model is first pretrained on the small, fully labeled subset that still has to be obtained. The unlabeled samples are then fed into the network and their class predictions are recorded. If the largest class probability of a sample exceeds a set threshold, the corresponding class is used as its ground truth. These samples can then be used to train the model in a supervised fashion. As the performance of the model gets better and better, the artificially obtained labels can iteratively be refined using the very same technique.
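A minimal PyTorch-style sketch of the pseudo-labeling step could look as follows. The tiny linear classifier, the random images, and the 0.95 threshold are placeholders for this illustration only; in practice, the model would already have been pretrained on the labeled subset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder classifier and a batch of fake unlabeled CIFAR-sized images.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
unlabeled_images = torch.rand(256, 3, 32, 32)

threshold = 0.95  # keep only predictions the model is very confident about

model.eval()
with torch.no_grad():
    probs = F.softmax(model(unlabeled_images), dim=1)  # class probabilities
    confidence, pseudo_label = probs.max(dim=1)        # best class per sample
    mask = confidence > threshold                      # confident samples only

confident_images = unlabeled_images[mask]
confident_labels = pseudo_label[mask]
# confident_images / confident_labels can now be mixed into the labeled set
# and trained on with an ordinary cross-entropy loss.
```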

The second paradigm is called consistency regularization, which trains the model to output similar predictions when it is fed two slightly different versions of the same image. These perturbed versions of the original image are typically obtained with data augmentation methods such as rotation, shifting, or contrast changes. Such training allows the model to generalize better and become more robust. Since we simply enforce similar predictions, no class labels are needed in this case. Hence, the unlabeled data can be used as it is.
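In a similar spirit, a single consistency-regularization step could be sketched like this, again with a placeholder model; the particular augmentations and the mean-squared-error consistency loss are just one possible choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

# Placeholder classifier and a batch of fake unlabeled images in [0, 1].
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
unlabeled_images = torch.rand(64, 3, 32, 32)

# Two randomly perturbed views of the same images (no labels involved).
# Note: applied to a whole batch, each call draws one set of random parameters;
# real pipelines usually sample the perturbation per image.
augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
])
view_a = augment(unlabeled_images)
view_b = augment(unlabeled_images)

probs_a = F.softmax(model(view_a), dim=1)
probs_b = F.softmax(model(view_b), dim=1)

# Penalize disagreement between the two predictions; this term is usually
# added with some weight to the supervised loss on the labeled data.
consistency_loss = F.mse_loss(probs_a, probs_b)
consistency_loss.backward()
```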

Semi-supervised learning can be used in applications where data is easy to obtain but hard to label. It makes use of both labeled and unlabeled data to build a model that is usually more powerful than one trained in the standard supervised fashion. These algorithms are often based on pseudo-labeling and/or consistency regularization.

One disclaimer, though: even if semi-supervised training often improves over standard supervised training, there is no guarantee that this is the case for your very own application. Research has shown that it may even degrade performance in some cases.

If you want to learn more about actual implementations of semi-supervised learning algorithms, such as FixMatch, stay tuned for my future articles!
