I still experience a moment of excitement whenever I start on a new ML project and look at my prediction results for the first time, not knowing whether this will be a good model or just garbage. I am not talking about metrics, but about the actual prediction outputs. There is something that many people working with algorithms may now take for granted, but that still touches me in this particular moment, when the predictions do ‘make sense’: a machine actually learned to produce something meaningful. After all the discussions about ML research coming to a halt, we should now and then show a little humility and realize that what we have achieved in the past few years is actually pretty amazing.
If you take this naive fascination with your pet machine doing something meaningful, the idea of multimodal models connects with it naturally. If you think of your machine as some kind of intelligent system (I am not talking about artificial general intelligence here), making it multimodal is like turning a blind and deaf worm into a multi-sense primate. A typical off-the-shelf machine learning model will only be good at one thing: language, vision, tabular data and so on. Multimodal models go beyond that. Just as humans have five senses (vision, touch, taste, smell and hearing), a multimodal model may ingest information belonging to different sense modalities.
The most popular multimodal models combine computer vision and natural language processing, which, strictly speaking, is not multimodal, since both text and images can be perceived through vision. However, in machine learning the idea of multimodality is broader and in most cases refers to using different data types (e.g. images, texts, tabular data) within the same model.
Let’s look at a more concrete example to see how multimodal models can become useful. Imagine a dataframe like this:
It contains all kinds of data: numerical values, categorical values, timestamps, texts and images. Traditional ML models face the challenge that they are specialized in handling one particular kind of data: language models handle text (e.g. an MLP over n-grams or a fine-tuned BERT), computer vision models handle images (e.g. a CNN trained in TensorFlow or PyTorch), time series models handle sequential data, and perhaps tree-based models handle ordinary tabular data (e.g. XGBoost). This leaves the data scientist with the challenge of either choosing one model over the others, or finding a way to combine them in a reasonable way.
Here, multimodal models come into play: while adding more data may help improve model performance, with multimodal models you are not only adding more data, you are broadening the range of information your model can learn from, all within one unified model architecture. The model may learn a joint representation of the different modalities, which should yield more meaningful representations. Think of it as different ways in which the same thing can be comprehended, combined into a unified representation of that thing.
In the retail industry for example, a product image may contain a lot of information about its shape and colours, while the product description contains information about the material and the functionality of the product. Classifying or comparing these objects will likely become easier when all components are taken into consideration.
Build a multimodal model
It is actually fairly easy to build a multimodal neural network. I will give you two examples of how a multimodal model can be defined and trained. Keep in mind that this is just for demonstration purposes. The models will be very simple. I am also not considering the broader context of machine learning workflows (reproducible model experimentation, hyperparameter tuning, model deployment, etc.) here.
For demonstration purposes, I will use the movies_metadata.csv from Kaggle. To keep it simple, I will only consider the following columns as features:
| 100 | 1350000 | 4.60786 | 3.89757e+06 | 105 | A card sharp and his unwillingly-enlisted friends need to […] door. |
| 10000 | 0 | 0.281609 | 0 | 116 | A group of tenants […] were. |
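Splitting the dataframe into the two input sources is a one-liner with pandas. The sketch below uses a tiny in-memory stand-in for the CSV; the column names (`budget`, `popularity`, `revenue`, `runtime`, `overview`) are assumptions based on the Kaggle dataset, so adjust them to your copy of the data.

```python
import pandas as pd

# Toy stand-in for movies_metadata.csv; in practice: pd.read_csv('movies_metadata.csv')
df = pd.DataFrame({
    'budget':     [100, 10000],
    'popularity': [4.60786, 0.281609],
    'revenue':    [3.89757e+06, 0.0],
    'runtime':    [105, 116],
    'overview':   ['A card sharp and his friends need to pay back a debt.',
                   'A group of tenants slowly discover who they were.'],
})

# Split into the two 'modalities': numerical features and free text
numerical_features = df[['budget', 'popularity', 'revenue', 'runtime']].fillna(0).astype('float64')
text_features = df['overview'].fillna('')
```

The numerical block goes into the feed-forward part of the model, the overview texts into the language model.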
The two ‘modalities’ we will consider here are just structured data and text as two input sources. I know, it’s not that multimodal, but using images wouldn’t be truly multimodal either 😄 (if you think of modalities in terms of modalities of perception).
The objective of the classification task will be to predict movie genres. We thus deal with a multilabel classification task.
| id | Action | Adventure | Animation | Comedy | Crime | Documentary | Drama | Family | Fantasy | Foreign | History | Horror | Music | Mystery | Romance | Science Fiction | TV Movie | Thriller | War | Western |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
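A label matrix like this can be built with scikit-learn's `MultiLabelBinarizer`. A minimal sketch, assuming the genres have already been parsed into lists of strings (in the real data they come from the JSON-like `genres` column):

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Assumed toy input: one genre list per movie
genres = [
    ['Comedy', 'Crime'],
    ['Horror', 'Thriller'],
    ['Comedy'],
]

mlb = MultiLabelBinarizer()
y_labels = mlb.fit_transform(genres)  # shape: (n_movies, n_distinct_genres)

print(mlb.classes_)  # column order of the label matrix
print(y_labels)      # 0/1 indicator matrix; several 1s per row are allowed
```

Each column corresponds to one genre, and a movie can have several 1s in its row, which is exactly what distinguishes multilabel from multiclass classification.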
First I build a stacked Keras model with two components: a pretrained language model called DistilBERT, basically a smaller version of the language model BERT, and a feed-forward neural network to handle the structured input data (i.e. conventional features like ‘popularity’, ‘budget’, etc.). After these components are defined, we will use Concatenate() to combine the output layers of both submodels and add a classification head on top.
```python
# Get pretrained language model with transformer architecture
transformer_model = TFDistilBertModel.from_pretrained('distilbert-base-uncased')

# Define input layers for DistilBERT
input_ids_in = tf.keras.layers.Input(shape=(512,), name='input_token', dtype='int32')
input_masks_in = tf.keras.layers.Input(shape=(512,), name='masked_token', dtype='int32')

# Extract the embedding of the [CLS] token, which summarizes the whole sequence
embedding_layer = transformer_model(input_ids_in, attention_mask=input_masks_in)[0]
cls_token = embedding_layer[:, 0, :]

# Add feed-forward layers on top of the language model
language_model = tf.keras.layers.BatchNormalization()(cls_token)
language_model = tf.keras.layers.Dense(128, activation='relu')(language_model)
language_model = tf.keras.layers.Dropout(0.2)(language_model)
language_model = tf.keras.layers.Dense(64, activation='relu')(language_model)
language_model = tf.keras.layers.Dense(32, activation='relu')(language_model)

# Add input and dense layer for the four numerical features
numerical_input = tf.keras.layers.Input(shape=(4,), name='numerical_input', dtype='float64')
numerical_layer = tf.keras.layers.Dense(20, activation='relu')(numerical_input)
```
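The two DistilBERT inputs, `input_token` and `masked_token`, are produced by the matching tokenizer from the transformers library. The sketch below only illustrates the format the model expects, using made-up token ids instead of a real vocabulary, so that the padding and attention-mask logic is visible:

```python
import numpy as np

MAX_LEN = 512  # matches the shape of the Input layers above


def pad_to_model_input(token_ids, max_len=MAX_LEN):
    """Pad a list of token ids to a fixed length and build the attention mask.

    The mask is 1 for real tokens and 0 for padding, so the transformer
    ignores the padded positions.
    """
    token_ids = list(token_ids)[:max_len]  # truncate overly long sequences
    ids = np.zeros(max_len, dtype='int32')
    mask = np.zeros(max_len, dtype='int32')
    ids[:len(token_ids)] = token_ids
    mask[:len(token_ids)] = 1
    return ids, mask


# Made-up token ids for one short overview text
ids, mask = pad_to_model_input([101, 2023, 3185, 2003, 102])
```

In practice `DistilBertTokenizer.from_pretrained('distilbert-base-uncased')` produces these arrays for you; the point here is only that each text becomes two fixed-length integer vectors.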
After defining separate input and processing layers for the text and the numerical input, the outputs of these layers are concatenated and fed to the classification head.
```python
# Concatenate both layers
concatted = tf.keras.layers.Concatenate()([language_model, numerical_layer])

# Add classification head with sigmoid activation (one output per genre label)
head = tf.keras.layers.Dense(44, activation='relu')(concatted)
head = tf.keras.layers.Dense(y_labels.shape[1], activation='sigmoid')(head)

# Define model with three input types; the first two come from the tokenizer
multimodal_model = tf.keras.Model(
    inputs=[input_ids_in, input_masks_in, numerical_input],
    outputs=head)

# Prevent DistilBERT from being trainable (freeze its weights)
transformer_model.trainable = False
```
If we wanted to incorporate image data into our multimodal model, we would simply add another model architecture (a pretrained model or a custom CNN) and also feed its output into the concatenation layer.
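A sketch of what such an image branch could look like, here with a small hand-rolled CNN instead of a pretrained backbone; the input size and layer widths are arbitrary choices for illustration:

```python
import tensorflow as tf

# Hypothetical image branch; in practice you might swap the Conv2D stack
# for a pretrained backbone, e.g. one from tf.keras.applications
image_input = tf.keras.layers.Input(shape=(64, 64, 3), name='image_input')
x = tf.keras.layers.Conv2D(16, 3, activation='relu')(image_input)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Conv2D(32, 3, activation='relu')(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
image_layer = tf.keras.layers.Dense(32, activation='relu')(x)

# image_layer would then simply be added to the Concatenate() call:
# Concatenate()([language_model, numerical_layer, image_layer])
image_branch = tf.keras.Model(image_input, image_layer)
```

The only change to the rest of the model is one extra entry in the concatenation and one extra entry in the input list.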
The model summary illustrates the flow through the network, where the outputs of both parts, the language model and the numerical-input MLP, are passed to the classification head.
```
multimodal_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_token (InputLayer)        [(None, 512)]        0
__________________________________________________________________________________________________
masked_token (InputLayer)       [(None, 512)]        0
__________________________________________________________________________________________________
tf_distil_bert_model (TFDistilB TFBaseModelOutput(la 66362880    input_token
                                                                 masked_token
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, 768)          0           tf_distil_bert_model
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 768)          3072        tf.__operators__.getitem
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          98432       batch_normalization
__________________________________________________________________________________________________
dropout_19 (Dropout)            (None, 128)          0           dense
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 64)           8256        dropout_19
__________________________________________________________________________________________________
numerical_input (InputLayer)    [(None, 4)]          0
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 32)           2080        dense_1
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 20)           100         numerical_input
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 52)           0           dense_2
                                                                 dense_3
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 44)           2332        concatenate
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 20)           900         dense_4
==================================================================================================
Total params: 66,478,052
Trainable params: 113,636
Non-trainable params: 66,364,416
__________________________________________________________________________________________________
```
The nice thing about building a multimodal neural network with a framework like Keras is that all layers live in one computation graph and are optimized with the same loss function on the same task (here we keep the weights of DistilBERT frozen and only train the layers on top). So even if the model architecture becomes more complex, Keras will automatically handle backpropagation and weight updates for you.
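Before fitting, the model still needs to be compiled. For a multilabel task with sigmoid outputs, binary cross-entropy is the usual choice, since it treats each genre output as an independent yes/no decision. A minimal sketch with a stand-in model (the optimizer and learning rate are common defaults, not settings taken from the article):

```python
import tensorflow as tf

# Stand-in model so the snippet is self-contained; in the article this
# would be the multimodal_model defined above (52 concatenated features,
# 20 genre outputs, matching the model summary)
inputs = tf.keras.layers.Input(shape=(52,))
outputs = tf.keras.layers.Dense(20, activation='sigmoid')(inputs)
model = tf.keras.Model(inputs, outputs)

# Binary cross-entropy scores each of the 20 sigmoid outputs independently,
# which is what multilabel classification requires
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy'])
```

Using `categorical_crossentropy` here would implicitly assume exactly one active label per movie, which contradicts the multilabel setup.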
When calling fit on the model, the different inputs need to be passed to the model as a list.
```python
# Get training data
input_ids, input_masks, numerical_input = preprocessing_multimodal_training_data()

history = multimodal_model.fit(
    [input_ids, input_masks, numerical_input],
    y_labels,
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    epochs=EPOCHS)
```
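Since the head uses sigmoid activations, `model.predict` returns one independent probability per genre rather than a single class. A small sketch of turning those probabilities into genre names, using a made-up prediction row, an assumed genre order, and a simple 0.5 threshold:

```python
import numpy as np

# Assumed genre column order and a made-up prediction row for one movie
genre_names = np.array(['Action', 'Comedy', 'Drama', 'Horror', 'Romance'])
probs = np.array([0.08, 0.91, 0.67, 0.02, 0.55])

# Every genre whose probability exceeds the threshold is predicted;
# unlike softmax, several labels can be active at once
THRESHOLD = 0.5
predicted = genre_names[probs > THRESHOLD]

print(predicted)  # → ['Comedy' 'Drama' 'Romance']
```

The threshold is a tuning knob of its own: lowering it trades precision for recall, and it can even be chosen per genre.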
Using the AutoKeras API
Another, even simpler option for building multimodal models is AutoKeras. The library offers a high-level wrapper around Keras and makes it very easy to define a model architecture with multimodal inputs. It provides a lot of other benefits, such as neural architecture search to automatically tune the model architecture.
```python
import autokeras as ak

# Define multilabel classification head; binary cross-entropy scores
# each label independently, which is what the multilabel task requires
head = ak.ClassificationHead(
    loss='binary_crossentropy',
    multi_label=True,
    metrics=['accuracy'])

# Get training data
text, title, numerical_input = preprocessing_autokeras_training_data()

# Define multimodal model with the AutoModel class
ak_multimodal_model = ak.AutoModel(
    inputs=[ak.TextInput(), ak.StructuredDataInput()],  # add as many input sources as you want
    outputs=head,  # use the predefined head as output
    overwrite=True,
    max_trials=1)

# Fit the model
ak_multimodal_model.fit(
    [text, numerical_input],  # pass both input sources as a list
    y_labels,                 # and the multilabel matrix as target
    epochs=EPOCHS)
```
Multimodal models are easy
Whether you want to use Keras or AutoKeras, I hope you agree that multimodal models are easy. AutoKeras makes it super simple to construct a multimodal model: just call ak.AutoModel and add two inputs. Using the base Keras API, on the other hand, gives you more flexibility to experiment with different model architectures and to incorporate other pretrained models. Whether or not a multimodal model will improve your prediction task should be carefully tested against simpler models, but from my experience it is definitely worth trying a multimodal approach whenever there is enough data available.