Transfer learning in healthcare: Tackling the curse of dimensionality
You are a parent and wish to teach your 8-year-old how to play the violin? OK. But what does this have to do with artificial intelligence (AI)?...
Welcome to the world of transfer learning.
Recent scientific experiments have shown that very young babies (as young as 9 months old) who learn music can significantly improve many of their cognitive functions, including future language acquisition. Children who learn to play music at a young age develop better verbal and language learning skills than those who don't, because they gain enhanced sound representation abilities and modify their brain connectivity.
For adults, learning a language is a great example. Have you not noticed that the more languages you learn, the easier it gets? Of course, it is much simpler to learn Italian when you already know Spanish, but that is mostly down to vocabulary similarities. Now, less obviously: why do we get better at learning Russian when we already know English and Arabic, languages that share no vocabulary with it? Probably because: 1) we have captured very general principles and abstract structures common to many languages, and 2) we have simply learned... how to learn.
At the end of the day, that makes us much faster learners at anything else.
Let's climb one level higher in the tower of abstraction and we understand something else: logic is also fundamental to transferring knowledge. The more you study mathematics, the more you develop your logical skills, and the faster you learn new fields where logical reasoning matters, such as philosophy. Just look at great thinkers like Leibniz, Pascal or Descartes: they were all mathematicians!
What is the conclusion here?
Learning one thing in one domain can help you learn many other things in different domains faster, because you reuse common abstract structures and logic. You get it: this idea had to be explored in AI!
Although deep learning algorithms have opened the path to a much stronger AI, current systems are still very limited. To train a deep-learning-based AI from scratch to tell a dog from a cat, you first need to show the machine thousands upon thousands of labeled images (here, cats and dogs). Teach this to your kid, and just a handful of examples would be enough, right? So what technology could be created in the AI world to perform as well as kids do?
One answer is transfer learning: a key hidden secret behind the success of deep learning, and probably the next revolution in AI.
What is it exactly? Transfer learning is a domain of AI that focuses on the ability of a machine learning algorithm to improve its performance on a given dataset thanks to previous exposure to a different one.
Let’s take two examples:
1) We wrote above that learning some languages can help you learn others.
You probably think this is only true for humans. No! It is also true for machines. Famous translation algorithms such as Google Translate are very sensitive to this phenomenon. Indeed, if you train such an algorithm to translate, say, from English to French, and you also train it to translate Spanish to Russian, or even Italian to German, you will find that the quality of the English-to-French translation improves markedly! This is stunning on a philosophical level: it means that higher cognitive phenomena are starting to be observed, incidentally, in the latest AI systems.
2) This second example is also very exciting: how transfer learning can help AI systems learn from a small number of cases.
It is probably the most common use of transfer learning in practice at the moment, and one of the hidden reasons why deep learning is such a success. Suppose you want to create an algorithm able to differentiate breast cancer from lung cancer based on medical images (such as MRI or pathology images). In principle, deep learning algorithms are very powerful for image recognition and should be the best choice, but training them from scratch would require 100k to 1M images to achieve good accuracy. For most medical applications this is a huge problem, because the datasets that physicians and researchers can build typically contain only a few hundred images at most!
At Owkin, we sometimes deal with such small datasets of around a hundred ultra-high-resolution (gigapixel) medical images, making our data matrix extremely horizontal (on the order of a hundred samples and billions of features). So what can we do? We proceed in two steps:
Step 1.
We import a huge source dataset called ImageNet, with over a million images spread across 1,000 categories (cats, trucks, computers, etc.), and train a deep convolutional neural network (CNN) to classify the images of this source dataset.
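In practice, you rarely need to run Step 1 yourself: the major deep learning libraries distribute CNNs that have already been trained on ImageNet. A minimal sketch, assuming PyTorch and torchvision (our choice of tooling, not something this post prescribes):

```python
from torchvision import models

# Obtain a CNN already trained on ImageNet (the result of Step 1).
# On older torchvision versions, use `pretrained=True` instead of `weights=`.
cnn = models.resnet18(weights="IMAGENET1K_V1")  # a 1000-class ImageNet classifier
```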
Step 2.
We use this CNN as a starting point and train it to classify the medical image dataset (the target dataset). This procedure is called fine-tuning from a warm restart. Another option is to extract the representation of the images at the penultimate layer of the deep neural network and train any standard classification algorithm on these new features; this is called feature extraction. There are many other strategies, such as penalizing weights by similarity or multi-task learning.
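To make the two options concrete, here is a minimal, hypothetical sketch in PyTorch (the framework, the ResNet-18 backbone, the two-class setup and the `train_loader` name are our assumptions, not part of the original text): the first part fine-tunes the whole pretrained network on the target medical dataset, the second part uses the penultimate-layer activations as fixed features for a standard classifier.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression
from torchvision import models

# Start from the ImageNet-pretrained CNN of Step 1.
cnn = models.resnet18(weights="IMAGENET1K_V1")

# --- Option A: fine-tuning from a warm restart --------------------------------
# Swap the 1000-class ImageNet head for our 2 medical classes
# (say, breast vs. lung cancer) and keep training at a small learning rate.
cnn.fc = nn.Linear(cnn.fc.in_features, 2)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

def fine_tune(loader, epochs=5):
    cnn.train()
    for _ in range(epochs):
        for images, labels in loader:          # batches of target medical images
            optimizer.zero_grad()
            loss = criterion(cnn(images), labels)
            loss.backward()
            optimizer.step()

# --- Option B: feature extraction ----------------------------------------------
# Drop the classification head, treat the penultimate-layer activations as
# fixed features, and fit any standard classifier on top of them.
backbone = nn.Sequential(*list(cnn.children())[:-1])   # everything but the final fc
backbone.eval()

def extract_features(loader):
    feats, labels = [], []
    with torch.no_grad():
        for images, y in loader:
            feats.append(backbone(images).flatten(1).numpy())  # (batch, 512)
            labels.append(y.numpy())
    return np.concatenate(feats), np.concatenate(labels)

# X, y = extract_features(train_loader)                 # `train_loader` is hypothetical
# clf = LogisticRegression(max_iter=1000).fit(X, y)
```

As a rough rule of thumb, fine-tuning tends to pay off when the target dataset is not too tiny, while feature extraction is often the safer choice with only a few hundred samples.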
Having this ability to transfer knowledge and experience from one task (telling a cat from a plane) to another (distinguishing benign from malignant tumors) is undoubtedly one of the key features explaining the recent success of deep learning technologies.
Indeed, deep learning architectures are very well suited to the transfer learning approach, probably much more so than any other type of machine learning algorithm.
Cool, we have now learned what transfer learning is!
We have shown that it is very powerful in the context of deep learning and explains part of its success. But beyond its importance for deep learning, why is this technology so important for the future of AI?
Well, it’s important for two main reasons.
Reason one:
The first reason is very simple, yet essential, and we have already shown an example above: transfer learning brings the power of machine learning to small datasets. This is particularly important in the context of medical studies, since the current trend is to gather more and more data for each patient while the number of patients with clean annotations usually remains quite small. Dealing with these horizontal datasets, for which standard statistical methods usually fail, is a big challenge. To go beyond classical feature selection and regularization methods, we need to integrate external knowledge, which is precisely what a human would do. One solution, therefore, is to look for other datasets that are somehow related in type or statistical distribution, and use a transfer learning approach to bring the information they contain to bear on the problem you want to solve.
Reason two:
With classical statistical methods, you always start from a blank page each time you address a new question. Burdensome, isn't it? With transfer learning, AI systems can come closer to what is so precious in the human intellectual adventure: we think collectively, we share knowledge and reasoning, and at the end of the day, we are dwarfs standing on the shoulders of giants. This ability of AI systems to collaborate is precisely the second reason why transfer learning is so important for the future of AI. We are entering a new era where AI systems can mutually cross-fertilize: each time a new problem is solved with a machine learning approach, our collective resource of AI algorithms grows, creating a library of algorithms, sometimes even called a model zoo.
With transfer learning, it becomes possible to pick pretrained algorithms from these libraries to enhance the predictive power of many other AI systems. This is incredibly powerful and contributes to a newly emerging system of global, collaborative artificial intelligence.
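As one concrete illustration (our choice of tooling, not something the post prescribes), PyTorch's `torch.hub` lets anyone pull a model published in such a public zoo and reuse it as the starting point of a new task:

```python
import torch

# Pull a ResNet-50 with ImageNet weights straight from a public model zoo
# (here the torchvision hub repository) and reuse it as a warm start.
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
model.eval()
```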
But why does transfer learning work so well in deep networks in the first place? Recall the two-step procedure above.
Well, deep neural networks encode the information contained in an image in a hierarchical way. The lowest layers detect low-level features such as contours and textures, a basic and universal task in almost any image recognition problem, while the highest layers detect more complex objects and abstract concepts.
Therefore, during Step 1, all the lowest layers are trained to discover basic shapes and textures on the huge ImageNet dataset, a set of skills that will be reused during Step 2. The neural network does not need to learn these basic skills from scratch and can build upon what was already learned on the huge dataset.
During Step 2, only the highest layers need to be updated to adjust to the new task, whereas the lower layers can remain unmodified, since their job is universal rather than task-specific.
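In code, this often amounts to freezing the early layers and only updating the last ones. A small, hypothetical PyTorch sketch (layer names follow torchvision's ResNet; the two-class head is our running medical example):

```python
import torch.nn as nn
from torchvision import models

cnn = models.resnet18(weights="IMAGENET1K_V1")

# Freeze all pretrained parameters: the low-level contour and texture
# detectors learned on ImageNet are kept exactly as they are.
for param in cnn.parameters():
    param.requires_grad = False

# Replace the head for the new task; freshly created layers are trainable by default.
cnn.fc = nn.Linear(cnn.fc.in_features, 2)

# Optionally unfreeze the top block as well, so the most task-specific
# high-level features can adapt to the medical images.
for param in cnn.layer4.parameters():
    param.requires_grad = True
```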
So when you hear or read about all these new technology startups that use deep learning algorithms to recognize faces or guide self-driving cars, you can safely bet that they are using a transfer learning strategy at some point in their research!
We love this idea of a collaborative AI at Owkin.
This is why we are building the first platform that enables every researcher or doctor in the world to create and share a new classification algorithm, empowered by transfer learning technologies. Users just need a single dataset of medical images with at least two different labels, whether for diagnosis, prognosis or drug response prediction. We'll talk about that in more detail in our next post!
To conclude, it is time to go beyond the classical approach where each problem is solved independently, starting again from a blank page every time. Let's move toward a global platform where transfer learning lets all these algorithms share their secrets, help each other and, in the end, accelerate medical discoveries.