Adding an attention mechanism to RNNs

This article is an excerpt from Machine Learning with PyTorch and Scikit-Learn, the new book from the widely acclaimed and bestselling Python Machine Learning series, fully updated and expanded to cover PyTorch, transformers, graph neural networks, and best practices.



In this article, we discuss the motivation behind developing an attention mechanism, which helps predictive models to focus on certain parts of the input sequence more than others, and how it was originally used in the context of RNNs.  

 

Attention helps RNNs with accessing information

 
To understand the development of an attention mechanism, consider the traditional RNN model for a seq2seq task like language translation, which parses the entire input sequence (for instance, one or more sentences) before producing the translation, as shown in Figure 1: 

Figure 1: A traditional RNN encoder-decoder architecture for a seq2seq modeling task
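To make the figure concrete, here is a minimal PyTorch sketch of such an encoder-decoder: the encoder consumes the entire source sequence and passes only its final hidden state to the decoder, which then generates the output tokens. The class names, layer sizes, and the choice of a GRU are illustrative assumptions, not code from the book.

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size=1000, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        # src: (batch, src_len) of token indices
        embedded = self.embedding(src)
        _, hidden = self.rnn(embedded)
        # hidden: (1, batch, hidden_dim) -- the single vector that must
        # summarize the entire input sequence
        return hidden

class Decoder(nn.Module):
    def __init__(self, vocab_size=1000, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tgt, hidden):
        # Decoding starts only after the encoder has consumed the whole input
        outputs, hidden = self.rnn(self.embedding(tgt), hidden)
        return self.out(outputs), hidden

encoder, decoder = Encoder(), Decoder()
src = torch.randint(0, 1000, (2, 7))   # two source sentences of length 7
tgt = torch.randint(0, 1000, (2, 5))   # two target sentences of length 5
logits, _ = decoder(tgt, encoder(src))
print(logits.shape)                    # torch.Size([2, 5, 1000])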

 

Why is the RNN parsing the whole input sentence before producing the first output? This is motivated by the fact that translating a sentence word by word would likely result in grammatical errors, as illustrated in Figure 2: 

 

Figure 2: Translating a sentence word by word can lead to grammatical errors

 

However, one limitation of the seq2seq approach shown in Figure 1 is that the RNN must remember the entire input sequence via a single hidden state before translating it. Compressing all the information into a single hidden state may cause a loss of information, especially for long sequences. Thus, similar to how humans translate sentences, it may be beneficial to have access to the whole input sequence at each time step. 

In contrast to a regular RNN, an attention mechanism lets the RNN access all input elements at each given time step. However, having access to all input sequence elements at each time step can be overwhelming. So, to help the RNN focus on the most relevant elements of the input sequence, the attention mechanism assigns different attention weights to each input element. These attention weights designate how important or relevant a given input sequence element is at a given time step. For example, revisiting Figure 2, the words "mir, helfen, zu" may be more relevant for producing the output word "help" than the words "kannst, du, Satz." 
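As a rough numerical illustration of these attention weights, the following sketch scores each encoder hidden state against the decoder's current state, normalizes the scores with a softmax, and forms a context vector as the weighted sum. The tensor names and sizes, and the simple dot-product scoring function, are assumptions for illustration; the original mechanism for RNNs uses a small neural network to compute the scores, as described in the next subsection.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
encoder_states = torch.randn(6, 128)   # one hidden vector per input word (6 words)
decoder_state = torch.randn(128)       # decoder hidden state at the current time step

# Score each input element against the current decoder state
scores = encoder_states @ decoder_state          # shape: (6,)

# Softmax turns the scores into attention weights that sum to 1;
# a larger weight marks an input word as more relevant at this step
attn_weights = F.softmax(scores, dim=0)          # shape: (6,)

# Context vector: the attention-weighted sum of all encoder states,
# so the decoder can "look at" the entire input at every time step
context = attn_weights @ encoder_states          # shape: (128,)
print(attn_weights.sum())                        # tensor(1.)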

The next subsection introduces an RNN architecture that was outfitted with an attention mechanism to help process long sequences for language translation. 

 

The original attention mechanism for RNNs

 
In this subsection, we will summarize the mechanics of the attention mechanism that was originally developed for language translation and first appeared in the following paper: Neural Machine Translation by Jointly Learning to Align and Translate by Bahdanau, D., Cho, K., and Bengio, Y., 2014, https://arxiv.org/abs/1409.0473.

Given an input sequence x = (x^(1), x^(2), ..., x^(T)), the attention mechanism assigns a weight to each element x^(i) (or, to be more specific, to its hidden representation) and helps the model identify which part of the input it should focus on. For example, suppose our input is a sentence, and a word with a larger weight contributes more to our understanding of the whole sentence. The RNN with the attention mechanism shown in Figure 3 (modeled after the previously mentioned paper) illustrates the overall concept of generating the second output word:  

 

Figure 3: RNN with attention mechanism

 

The attention-based architecture depicted in the figure consists of two RNN models.  
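As a minimal sketch of the attention computation itself, the following module implements additive (Bahdanau-style) attention: a small feed-forward network scores each encoder hidden state against the current decoder state, the scores are normalized with a softmax into attention weights, and the context vector is their weighted sum. The layer names and dimensions are illustrative assumptions rather than the book's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    def __init__(self, enc_dim=128, dec_dim=128, attn_dim=64):
        super().__init__()
        self.W_enc = nn.Linear(enc_dim, attn_dim, bias=False)
        self.W_dec = nn.Linear(dec_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_states):
        # decoder_state: (batch, dec_dim); encoder_states: (batch, src_len, enc_dim)
        # Alignment scores: compare the decoder state with every encoder state
        scores = self.v(torch.tanh(
            self.W_enc(encoder_states) + self.W_dec(decoder_state).unsqueeze(1)
        )).squeeze(-1)                            # (batch, src_len)
        attn_weights = F.softmax(scores, dim=-1)  # (batch, src_len)
        # Context vector: attention-weighted sum of the encoder states
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_states).squeeze(1)
        return context, attn_weights

attn = AdditiveAttention()
enc = torch.randn(2, 7, 128)            # two sentences, 7 words, hidden size 128
dec = torch.randn(2, 128)               # current decoder hidden state
context, weights = attn(dec, enc)
print(context.shape, weights.shape)     # torch.Size([2, 128]) torch.Size([2, 7])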
Learn more with Machine Learning with PyTorch and Scikit-Learn by Sebastian Raschka, Yuxi (Hayden) Liu, and Vahid Mirjalili.