The example below shows an implementation of a transformer block as a Keras layer, which can be used in place of an LSTM or GRU layer to process sequential input. We take the mean of the transformer outputs across all time steps and use a feed-forward network on top of it to classify text.

  • The transformer shown in the example is a small one; 2 attention heads, and small dimensions for projection heads and FFN. It achieves validation accuracy of 87% on IMDB in less than a minute on CPU.
  • The model trains faster than an LSTM based classifier with a similar number of parameters and gives slightly better results.
  • Positional embeddings are implemented using a standard keras.layers.Embedding layer, which according to the original paper produces nearly identical results to the sinusoidal version.


import tensorflow as tf
from tensorflow import keras

Implement multi head self attention as a Keras layer

class MultiHeadSelfAttention(keras.layers.Layer):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three dense layers,
    split into ``num_heads`` heads of size ``embed_dim // num_heads``,
    attended independently per head, merged back to ``embed_dim`` and passed
    through a final dense layer.

    Args:
        embed_dim: Width of the query/key/value projections and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim`` evenly.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert (
            embed_dim % num_heads == 0
        ), "embedding dimension not divisible by num heads"
        # Per-head feature size; the heads together span embed_dim.
        self.projection_dim = embed_dim // num_heads
        self.wq = keras.layers.Dense(embed_dim)
        self.wk = keras.layers.Dense(embed_dim)
        self.wv = keras.layers.Dense(embed_dim)
        self.combine_heads = keras.layers.Dense(embed_dim)

    def attention(self, q, k, v):
        """Return ``(output, weights)`` of scaled dot-product attention.

        ``weights`` is the softmax over the last axis of ``q @ k^T`` divided
        by the square root of the size of ``k``'s last dimension; ``output``
        is ``weights @ v``.
        """
        score = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dk)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, v)
        return output, weights

    def separate_heads(self, x, batch_size):
        """Reshape ``x`` to (batch_size, num_heads, seq_len, projection_dim)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, x):
        """Apply self-attention to ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).
        """
        # x.shape = [batch_size, seq_len, embedding_dim]
        batch_size = tf.shape(x)[0]
        q = self.wq(x)  # (batch_size, seq_len, embed_dim)
        k = self.wk(x)  # (batch_size, seq_len, embed_dim)
        v = self.wv(x)  # (batch_size, seq_len, embed_dim)
        q = self.separate_heads(
            q, batch_size
        )  # (batch_size, num_heads, seq_len, projection_dim)
        k = self.separate_heads(
            k, batch_size
        )  # (batch_size, num_heads, seq_len, projection_dim)
        v = self.separate_heads(
            v, batch_size
        )  # (batch_size, num_heads, seq_len, projection_dim)
        # The attention weights are not needed here, only the attended values.
        attention, _ = self.attention(q, k, v)
        attention = tf.transpose(
            attention, perm=[0, 2, 1, 3]
        )  # (batch_size, seq_len, num_heads, projection_dim)
        concat_attention = tf.reshape(
            attention, (batch_size, -1, self.embed_dim)
        )  # (batch_size, seq_len, embed_dim)
        # Mix information across heads with the output projection.
        output = self.combine_heads(
            concat_attention
        )  # (batch_size, seq_len, embed_dim)
        return output

Implement transformer block as a layer

class TransformerLayer(keras.layers.Layer):
    """Transformer block: self-attention and a feed-forward network.

    Each sub-layer is followed by dropout, added back to its input
    (residual connection) and then layer-normalized.

    Args:
        embed_dim: Feature size of the block's input and output.
        num_heads: Number of attention heads passed to the attention layer.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied after each sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()

        self.att = MultiHeadSelfAttention(embed_dim, num_heads)
        # The second Dense projects back to embed_dim so that the residual
        # addition in call() has matching shapes.
        self.ffn = keras.Sequential(
            [
                keras.layers.Dense(ff_dim, activation="relu"),
                keras.layers.Dense(embed_dim),
            ]
        )

        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)

    def call(self, x, training):
        """Run the block on ``x``.

        Args:
            x: Input tensor; its last dimension must equal ``embed_dim``.
            training: Forwarded to the dropout layers to toggle dropout.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_output = self.att(x)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2

Implement embedding layer

Two separate embedding layers: one for tokens, one for token indices (positions).

class EmbeddingLayer(keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        maxlen: Number of distinct positions the position embedding covers.
        vocab_size: Number of distinct token ids the token embedding covers.
        emded_dim: Output size of both embeddings. The spelling of this
            parameter name is kept as-is so existing keyword callers still work.
    """

    def __init__(self, maxlen, vocab_size, emded_dim):
        super().__init__()
        self.token_emb = keras.layers.Embedding(
            input_dim=vocab_size, output_dim=emded_dim
        )
        self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=emded_dim)

    def call(self, x):
        """Embed token ids ``x`` and add position embeddings.

        The positions are ``0 .. seq_len - 1`` where ``seq_len`` is the size of
        the last axis of ``x``; the position embedding only has ``maxlen``
        rows, so ``seq_len`` must not exceed ``maxlen``.
        """
        seq_len = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        # (seq_len, emded_dim); broadcasts over the leading axes of the
        # token embeddings when added below.
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

Create classifier model using transformer layer

The transformer layer outputs one vector for each time step of the input sequence. Here, we take the mean across all time steps and build a two-layer feed-forward network on top of it.

class TransformerClassifier(tf.keras.Model):
    """Text classifier: embeddings, one transformer block, mean-pool, MLP head.

    Args:
        maxlen: Number of positions passed to the embedding layer.
        vocab_size: Number of token ids passed to the embedding layer.
        embed_dim: Embedding / transformer feature size.
        ff_dim: Hidden width of the transformer's feed-forward network.
        num_heads: Number of attention heads in the transformer block.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, ff_dim, num_heads):
        super().__init__()
        # Sub-layers are created in a fixed order; Keras tracks them in the
        # order they are assigned.
        self.emb = EmbeddingLayer(maxlen, vocab_size, embed_dim)
        self.transformer = TransformerLayer(embed_dim, num_heads, ff_dim)
        self.prehead = keras.layers.Dense(20, activation="relu")
        self.dropout1 = keras.layers.Dropout(0.05)
        self.dropout2 = keras.layers.Dropout(0.05)
        self.head = keras.layers.Dense(2, activation="softmax")

    def call(self, x, training):
        """Return softmax scores over the two classes for token ids ``x``."""
        embedded = self.emb(x)
        encoded = self.transformer(embedded, training)
        # Collapse the time axis: one vector per example.
        pooled = tf.math.reduce_mean(encoded, axis=1)
        pooled = self.dropout1(pooled, training=training)
        hidden = self.prehead(pooled)
        hidden = self.dropout2(hidden, training=training)
        return self.head(hidden)

Download and prepare dataset

max_features = 6000  # Only consider the 6000 most frequent words
maxlen = 200  # Pad or truncate every movie review to 200 tokens
# Load the IMDB reviews as lists of word indices, keeping only the
# `max_features` most frequent words.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
    num_words=max_features
)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")
# NOTE: pad_sequences pads and truncates at the front by default, so reviews
# longer than `maxlen` keep their last `maxlen` tokens.
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=maxlen)

Downloading data from
17465344/17464789 [==============================] - 2s 0us/step
25000 Training sequences
25000 Validation sequences

Train and Evaluate

# Build the classifier with the vocabulary size and sequence length used
# when preparing the data.
model = TransformerClassifier(
    maxlen=maxlen, vocab_size=max_features, embed_dim=32, ff_dim=32, num_heads=1
)
# Labels are integer class ids and the model outputs a 2-way softmax, hence
# the sparse categorical cross-entropy loss.
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(
    x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val)
)

Epoch 1/2
782/782 [==============================] - 29s 37ms/step - loss: 0.3767 - accuracy: 0.8248 - val_loss: 0.3363 - val_accuracy: 0.8561
Epoch 2/2
782/782 [==============================] - 30s 39ms/step - loss: 0.2497 - accuracy: 0.8992 - val_loss: 0.2959 - val_accuracy: 0.8716