Multimodal Retrieval: Learning Deep Supervised Representations with Online Hard Triplet Mining
Finetuning a modified version of OpenAI’s CLIP to leverage already aligned image and text encoders for accurate, jointly learned, multimodal data representations
Representation learning has gained increasing attention due to its broad applicability in modern information retrieval systems. Its utility is undeniable, from facial recognition to the latest retrieval-augmented generation (RAG) pipelines — an emerging method for enhancing generative AI models by supplying relevant context. The integration of multimodal data plays a crucial role, as different data types can complement one another to map subjects in an embedding space more accurately.
We can produce representations that capture even the slightest details by fine-tuning pre-trained encoders with supervised contrastive learning using batch hard triplet losses. This approach teaches models to separate instances from different classes while pulling together same-class samples that lie far apart under a chosen distance function.
In this article, I’ll give you an intuitive understanding of supervised multimodal representation learning with online mined hard triplets. Then, we’ll fine-tune OpenAI’s already aligned CLIP image and text encoders to build a specialized multimodal movie retrieval system, learning detailed representations for movie plot outlines paired with poster images, based on annotated genres. During the demo, I also want to focus on how easily this can be done with a library that I open-sourced, called SuperTriplets.
Concepts and Intuition
Multimodal data
Multimodal data refers to data that comes in multiple modalities or formats, such as images, text, audio, and video. In many real-world scenarios, the information available about an object or concept can be more comprehensive when combining different modalities. For example, an image of a dish may be accompanied by a description of its ingredients and cooking instructions, enhancing the understanding of it.
Despite the prevalence of multimodal data, best practices for modeling it are still not well established, since it poses intricate challenges: merging diverse data types like images and text into one cohesive representation space demands handling differences in data distribution, scale, and modality-specific meaning. Complex relationships between words and pictures, coupled with variation in how language and visuals are interpreted, require advanced techniques.
Supervised contrastive learning
Supervised contrastive learning uses contrastive principles to help the network differentiate between classes. By pairing each data sample with other samples from the same class and contrasting them with samples from different classes, the network learns to construct more meaningful and semantically rich embeddings.
Researchers have demonstrated its effectiveness across various domains, such as computer vision and natural language processing, achieving state-of-the-art results on tasks that require specialized representation learning, such as content retrieval and face matching. It also contributes to exploring feature representations and advancing deep learning models in supervised settings.
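To make the pairing idea concrete, here is a minimal sketch in plain PyTorch (independent of the libraries used later in this article) that derives positive and negative pair masks from a batch of class labels; a supervised contrastive or triplet objective then combines these masks with a distance or similarity matrix:
import torch

# toy batch: six samples, three classes
labels = torch.tensor([0, 0, 1, 1, 2, 2])

# positives share the anchor's label (excluding the anchor itself)
same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
not_self = ~torch.eye(len(labels), dtype=torch.bool)
positive_mask = same_label & not_self

# negatives have a different label
negative_mask = ~same_label

print(positive_mask.int())
print(negative_mask.int())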
Online hard triplet learning
One powerful approach to achieving state-of-the-art embeddings with contrastive learning is online/batch hard triplet mining. The concept of triplets is central to this technique. A triplet consists of three samples: an anchor, a positive, and a negative instance. In the supervised mode, the anchor and the positive are from the same class.
A model trained with a triplet loss learns to place the anchor closer to the positive than to the negative in the embedding space. However, randomly selecting triplets during training can lead to slow convergence and suboptimal results.
To address this, we employ online hard triplet mining, which dynamically selects the hardest triplets within each training batch. This focuses the training process on the most informative and challenging instances, leading to more robust and fine-grained learned representations.
As expected, a triplet loss is applied to minimize the distance between the anchor and the positive sample, while maximizing the distance between the anchor and the negative sample. An important consideration is the margin parameter, which enforces that the positive sample must be closer to the anchor than the negative sample by at least a given margin. Without this margin, the network could output zero for all embeddings and incur little to no penalty.
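As a concrete illustration, here is a minimal batch-hard triplet loss written in plain PyTorch (a didactic sketch, not the implementation used later in this article): for every anchor it picks the hardest positive and hardest negative in the batch and applies the margin described above.
import torch

def batch_hard_triplet_loss(embeddings, labels, margin=1.0):
    # pairwise squared euclidean distances; squaring avoids sqrt's undefined gradient at zero
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)  # (batch, batch, dim)
    dist = diff.pow(2).sum(dim=-1)                            # (batch, batch)

    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=labels.device)

    # hardest positive per anchor: the farthest sample sharing the anchor's label
    hardest_positive = dist.masked_fill(~(same_label & not_self), 0.0).max(dim=1).values
    # hardest negative per anchor: the closest sample with a different label
    hardest_negative = dist.masked_fill(same_label, float("inf")).min(dim=1).values

    # the positive must be closer than the negative by at least `margin`, otherwise we pay a penalty
    return torch.relu(hardest_positive - hardest_negative + margin).mean()

# toy usage: every label appears at least twice, so each anchor has a valid positive
embeddings = torch.randn(8, 16, requires_grad=True)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
loss = batch_hard_triplet_loss(embeddings, labels, margin=0.5)
loss.backward()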
OpenAI CLIP
CLIP (Contrastive Language-Image Pre-Training) is an open-source deep learning model introduced by OpenAI in early 2021 that aims to learn visual concepts with natural language supervision, taking a huge leap forward in zero-shot image classification and representation.
Intuitively, it is pre-trained with a contrastive objective that aligns an image encoder and a text encoder, so that an image of an object and its textual description produce nearby embeddings under cosine similarity. With that, CLIP is capable of zero-shot image classification on unseen data, provided we supply the labels of the candidate classes to be recognized.
Note that CLIP doesn’t use multimodality to encode an entity but rather tries its best to align the image and sentence embeddings.
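To make the zero-shot capability concrete, here is roughly what classification with the pre-trained checkpoint looks like using Hugging Face Transformers (the image path and candidate labels below are just placeholders):
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("some_movie_poster.jpg")  # any local image
candidate_labels = ["a horror movie poster", "a romantic comedy poster", "a sci-fi movie poster"]

inputs = processor(text=candidate_labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# image-text similarity scores turned into probabilities over the candidate labels
probs = outputs.logits_per_image.softmax(dim=-1)
print(dict(zip(candidate_labels, probs[0].tolist())))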
Building a Multimodal Movie Retrieval System
Let’s see a quick demonstration of how to apply all that knowledge. We’ll fine-tune CLIP-based embeddings for pairs of movie cover photos and plot outlines to build a multimodal movie search engine.
The core tech used here is SuperTriplets, PyTorch, FAISS, and Streamlit. The code and instructions to reproduce everything presented here are on GitHub. Before starting, I’ll detail some of the most relevant components of the proposed demo.
Demo components
Tiny MM-IMDb — MM-IMDb is the largest publicly available multimodal dataset for movie genre prediction. It starts from the movies in the MovieLens 20M dataset and expands them by collecting genre, poster, and plot information for each title. Inspired by MM-IMDb, I’ve built Tiny MM-IMDb, a subset of its movies and fields focused on modeling, which is what we’ll use in this demo.
CLIP-based neural network architecture: we are going to use CLIP image and text encoders, but since we want to learn multimodal embeddings, we will simply add each modality representation to get a single embedding for each sample.
Supervised hard triplet learning with SuperTriplets — SuperTriplets is an open-source toolbox for supervised online hard triplet learning. It doesn’t try to automate the training and evaluation loop for us. Instead, it provides useful PyTorch-based utilities we can couple to existing code, making the process as simple as performing other everyday supervised learning tasks, such as classification and regression.
Facebook AI Similarity Search (FAISS) — FAISS is a library from Meta AI that makes it easy to build efficient (approximate) nearest-neighbor search indexes. We’ll use it to index our movie embeddings.
Streamlit — with Streamlit we can build a web app in a few lines of pure Python code. We are going to use it to build the user interface with our retrieval system.
Setting up the dev environment
Make sure you have python3 and venv installed on your machine. Ideally, a GPU should be available as well.
Clone the repo, create a virtual environment, and activate it:
$ git clone https://github.com/gabrieltardochi/multimodal-movie-retrieval-pytorch.git
$ cd multimodal-movie-retrieval-pytorch/
$ python3 -m venv .venv
$ source .venv/bin/activate # linux
Finally, install project dependencies:
$ pip install -r requirements.txt
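Optionally, check whether PyTorch can see your GPU before moving on:
$ python -c "import torch; print(torch.cuda.is_available())"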
Acquiring data
First of all, download and unzip the Tiny MM-IMDb dataset from Kaggle.
Now, create a .env file, making sure that TINY_MMIMDB_DATASET_PATH is defined. Here’s how the .env file should look so far:
TINY_MMIMDB_DATASET_PATH=<SAVED DATASET PATH HERE> # e.g. /home/gabriel/datasets/tinymmimdb
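The movie_retrieval.config module presumably just loads these variables with python-dotenv; a minimal sketch of what such a config could look like (the actual repository may differ):
import os

from dotenv import load_dotenv

# read KEY=VALUE pairs from the .env file into the process environment
load_dotenv()

# fail fast if the dataset path was not configured
TINY_MMIMDB_DATASET_PATH = os.environ["TINY_MMIMDB_DATASET_PATH"]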
Finetuning CLIP-based multimodal encoder
The training script is located in the movie_retrieval.train module. Here I’ll focus on the core training function and on how it differs from the usual PyTorch training code.
The imports are the following:
import argparse
import json
import os
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
import matplotlib.pyplot as plt
import pandas as pd
import torch
from dotenv import load_dotenv
from supertriplets.dataset import OnlineTripletsDataset, StaticTripletsDataset
from supertriplets.distance import EuclideanDistance
from supertriplets.encoder import PretrainedSampleEncoder
from supertriplets.evaluate import HardTripletsMiner, TripletEmbeddingsEvaluator
from supertriplets.loss import BatchHardTripletLoss
from supertriplets.models import load_pretrained_model
from supertriplets.sample import TextImageSample
from supertriplets.utils import move_tensors_to_device
from torch.utils.data import DataLoader
from tqdm import tqdm
from movie_retrieval.config import TINY_MMIMDB_DATASET_PATH
from movie_retrieval.utils import prepare_tinymmimdb_split
Now, let’s zoom in on the core train()
function and its parameters. We leverage argparse
and a TrainingConfig
class to configure the training loop:
@dataclass(frozen=True)
class TrainingConfig:
in_batch_num_samples_per_label: int
batch_size: int
learning_rate: float
num_epochs: int
def train(cfg: TrainingConfig) -> None:
...
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Training configuration parser")
parser.add_argument(
"--in_batch_num_samples_per_label",
type=int,
default=2,
help="Number of samples per label in a batch (default: 2)",
)
parser.add_argument(
"--batch_size", type=int, default=32, help="Size of each batch (default: 32)"
)
parser.add_argument(
"--learning_rate",
type=float,
default=2e-5,
help="Learning rate for training (default: 2e-5)",
)
parser.add_argument(
"--num_epochs",
type=int,
default=2,
help="Number of epochs for training (default: 2)",
)
args = parser.parse_args()
training_config = TrainingConfig(
in_batch_num_samples_per_label=args.in_batch_num_samples_per_label,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
num_epochs=args.num_epochs,
)
train(cfg=training_config) # start training
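Assuming the module can be executed directly (check the repository’s README for the exact entry point), a training run would then be launched with something like:
$ python -m movie_retrieval.train --batch_size 32 --num_epochs 2 --learning_rate 2e-5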
Now to the training loop. We start by reading the data and instantiating supertriplets.sample.TextImageSample
objects for each datapoint in our training, validation, and test splits:
def train(cfg: TrainingConfig) -> None:
# setup experiment saving config
current_datetime = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d-%H-%M-%S-%f")
save_path = os.path.join("artifacts", current_datetime)
os.makedirs(save_path, exist_ok=False)
# read dataset splits
train = prepare_tinymmimdb_split("train")
valid = prepare_tinymmimdb_split("dev")
test = prepare_tinymmimdb_split("test")
# might take too long on the cpu, make sure you have a gpu available
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# our data points are pairs of (poster image, plot outline) labeled with the movie genre identifier
train_examples = [
TextImageSample(text=text, image_path=image_path, label=label)
for text, image_path, label in zip(
train["plot_outline"], train["image_path"], train["genre_id"]
)
]
valid_examples = [
TextImageSample(text=text, image_path=image_path, label=label)
for text, image_path, label in zip(
valid["plot_outline"], valid["image_path"], valid["genre_id"]
)
]
test_examples = [
TextImageSample(text=text, image_path=image_path, label=label)
for text, image_path, label in zip(
test["plot_outline"], test["image_path"], test["genre_id"]
)
]
...
Within the training loop, each sample can be combined into different triplets, since the hardest triplets are “mined” in every batch (online). But the same can’t be done for the validation and test datasets: we need them fixed to measure model performance before, during, and after training.
To overcome this, we use a pre-trained model from supertriplets.encoder.PretrainedSampleEncoder
to encode validation and test movies, then we use supertriplets.evaluate.HardTripletsMiner
to extract hard validation and test triplets:
def train(cfg: TrainingConfig) -> None:
...
# load a pretrained encoder just to help find hard triplets in valid and test set
pretrained_encoder = PretrainedSampleEncoder(modality="text_english-image")
valid_embeddings = pretrained_encoder.encode(
examples=valid_examples, device=device, batch_size=cfg.batch_size
)
test_embeddings = pretrained_encoder.encode(
examples=test_examples, device=device, batch_size=cfg.batch_size
)
del pretrained_encoder
# this will efficiently find hard positives and negatives from a number of labeled samples
hard_triplet_miner = HardTripletsMiner(use_gpu_powered_index_if_available=True)
(
valid_anchor_examples,
valid_positive_examples,
valid_negative_examples,
) = hard_triplet_miner.mine(
examples=valid_examples,
embeddings=valid_embeddings,
normalize_l2=True,
sample_from_topk_hardest=10,
)
(
test_anchor_examples,
test_positive_examples,
test_negative_examples,
) = hard_triplet_miner.mine(
examples=test_examples,
embeddings=test_embeddings,
normalize_l2=True,
sample_from_topk_hardest=10,
)
del hard_triplet_miner
Finally, we are ready to load the pre-trained model for further fine-tuning:
def train(cfg: TrainingConfig) -> None:
...
# init model for finetuning, any torch nn would work
model = load_pretrained_model(model_name="CLIPViTB32EnglishEncoder")
model.to(device)
model.eval()
We are leveraging a torch.nn.Module implemented by SuperTriplets. It’s worth taking a closer look at it; you can also find its implementation on GitHub:
from typing import Dict, Optional, Union
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from torch import Tensor, nn
from transformers import (
AutoConfig,
AutoTokenizer,
CLIPImageProcessor,
CLIPModel,
CLIPTokenizer,
)
class CLIPViTB32EnglishEncoder(nn.Module):
def __init__(self):
super().__init__()
self.hf_image_and_text_model_name = "openai/clip-vit-base-patch32"
self.image_and_text_config = AutoConfig.from_pretrained(self.hf_image_and_text_model_name)
self.processor = CLIPImageProcessor.from_pretrained(
self.hf_image_and_text_model_name, config=self.image_and_text_config
)
self.tokenizer = CLIPTokenizer.from_pretrained(
self.hf_image_and_text_model_name, config=self.image_and_text_config
)
clip_image_and_text = CLIPModel.from_pretrained(
self.hf_image_and_text_model_name, config=self.image_and_text_config
)
self.vision_model = clip_image_and_text.vision_model
self.visual_projection = clip_image_and_text.visual_projection
self.text_model = clip_image_and_text.text_model
self.textual_projection = clip_image_and_text.text_projection
def load_input_example(
self,
text: str,
image_path: str,
label: int,
return_tensors: str = "pt",
truncation: bool = True,
padding: str = "max_length",
max_length: Optional[int] = None,
) -> Dict[str, Union[Dict[str, Tensor], Tensor]]:
text_input = self.tokenizer(
text,
return_tensors=return_tensors,
truncation=truncation,
padding=padding,
max_length=max_length,
)
text_input = {k: v.squeeze() for k, v in text_input.items()}
image_input = self.processor(Image.open(image_path), return_tensors=return_tensors).data
image_input["pixel_values"] = image_input["pixel_values"].squeeze()
label = torch.tensor(label)
return {"text_input": text_input, "image_input": image_input, "label": label}
def forward(self, image_input: Dict[str, Tensor], text_input: Dict[str, Tensor]) -> Tensor:
image_features = self.get_image_features(image_input)
text_features = self.get_text_features(text_input)
multimodal_features = image_features + text_features
return multimodal_features.squeeze()
def get_image_features(self, image_input: Dict[str, Tensor]) -> Tensor:
vision_pooled_output = self.vision_model(**image_input)[1] # pooled_output
image_features = self.visual_projection(vision_pooled_output)
return image_features
def get_text_features(self, text_input: Dict[str, Tensor]) -> Tensor:
text_token_embeddings = self.text_model(**text_input)[0] # all token embeddings
text_pooled_output = self.text_mean_pooling(text_token_embeddings, text_input["attention_mask"])
text_features = self.textual_projection(text_pooled_output)
return text_features
@staticmethod
def text_mean_pooling(token_embeddings: Tensor, attention_mask: Tensor) -> Tensor:
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
We can see that inputs for both modalities are preprocessed with load_input_example(). Then, in the forward() pass, each modality goes through pooling followed by a linear projection, and the two resulting vectors are simply added together to produce the final embedding.
Now, on to creating our datasets and dataloaders:
def train(cfg: TrainingConfig) -> None:
...
# train dataset will have its hard triplets mined for each training batch,
# it's a torch.utils.data.IterableDataset implementation
trainset = OnlineTripletsDataset(
examples=train_examples,
in_batch_num_samples_per_label=cfg.in_batch_num_samples_per_label,
batch_size=cfg.batch_size,
sample_loading_func=model.load_input_example,
sample_loading_kwargs={},
)
# validset and testset triplets were mined earlier; keep them static to measure performance improvements,
# it's a torch.utils.data.Dataset implementation
validset = StaticTripletsDataset(
anchor_examples=valid_anchor_examples,
positive_examples=valid_positive_examples,
negative_examples=valid_negative_examples,
sample_loading_func=model.load_input_example,
sample_loading_kwargs={},
)
testset = StaticTripletsDataset(
anchor_examples=test_anchor_examples,
positive_examples=test_positive_examples,
negative_examples=test_negative_examples,
sample_loading_func=model.load_input_example,
sample_loading_kwargs={},
)
# init dataloaders
trainloader = DataLoader(
dataset=trainset,
batch_size=cfg.batch_size,
num_workers=0,
drop_last=True,
)
validloader = DataLoader(
dataset=validset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=0,
drop_last=False,
)
testloader = DataLoader(
dataset=testset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=0,
drop_last=False,
)
SuperTriplets’ supertriplets.dataset.OnlineTripletsDataset and supertriplets.dataset.StaticTripletsDataset are implementations of torch.utils.data.IterableDataset and torch.utils.data.Dataset, respectively.
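The in_batch_num_samples_per_label parameter exists because online mining only works if every batch contains several samples per label; otherwise there are no positives to mine. Conceptually, the batch composition follows the idea sketched below (a simplification, not SuperTriplets’ actual implementation):
import random
from collections import defaultdict

def sample_balanced_batch(samples, labels, num_labels_per_batch=16, num_samples_per_label=2):
    # group samples by label
    by_label = defaultdict(list)
    for sample, label in zip(samples, labels):
        by_label[label].append(sample)

    # only labels with enough samples can contribute anchor/positive pairs
    eligible = [label for label, group in by_label.items() if len(group) >= num_samples_per_label]
    chosen = random.sample(eligible, k=min(num_labels_per_batch, len(eligible)))

    batch = []
    for label in chosen:
        batch.extend(random.sample(by_label[label], k=num_samples_per_label))
    return batch  # e.g. 16 labels x 2 samples per label = a batch of 32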
Before training, we should measure the initial performance of the model on the validation and test datasets. We can do this with supertriplets.evaluate.TripletEmbeddingsEvaluator, which calculates the “accuracy” (here, being accurate means distance[anchor, positive] < distance[anchor, negative]) for a variety of distance metrics. There is also a utility function get_triplet_embeddings(), omitted for simplicity, that helps us encode samples:
def train(cfg: TrainingConfig) -> None:
...
# calculate initial anchor, positive and negative embeddings for valid and test datasets
valid_triplet_embeddings = get_triplet_embeddings(
dataloader=validloader, model=model, device=device
)
test_triplet_embeddings = get_triplet_embeddings(
dataloader=testloader, model=model, device=device
)
# init evaluator, abstracts evaluating triplet embeddings using different distance metrics
triplet_embeddings_evaluator = TripletEmbeddingsEvaluator(
calculate_by_cosine=True,
calculate_by_manhattan=True,
calculate_by_euclidean=True,
)
# getting initial valid accuracy (means distance[anchor, positive] < distance[anchor, negative])
valid_start_accuracy = triplet_embeddings_evaluator.evaluate(
embeddings_anchors=valid_triplet_embeddings["anchors"],
embeddings_positives=valid_triplet_embeddings["positives"],
embeddings_negatives=valid_triplet_embeddings["negatives"],
)
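For reference, the “accuracy” reported by the evaluator boils down to the following check, shown here with Euclidean distance in plain NumPy (a sketch of the metric, not the evaluator’s code):
import numpy as np

def triplet_accuracy(anchors, positives, negatives):
    # a triplet counts as correct when the anchor is closer to its positive than to its negative
    d_ap = np.linalg.norm(anchors - positives, axis=1)
    d_an = np.linalg.norm(anchors - negatives, axis=1)
    return float(np.mean(d_ap < d_an))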
There are two online triplet losses implemented within supertriplets.loss: BatchHardTripletLoss and BatchHardSoftMarginTripletLoss. Both select, for each anchor, the hardest positive and the hardest negative in the batch, pulling similar inputs closer in the embedding space and pushing dissimilar inputs farther apart. The first uses a fixed margin by which the anchor-positive distance must be smaller than the anchor-negative distance to avoid a penalty, while the second uses a soft-margin formulation.
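In isolation, the difference between the two formulations looks like this (a sketch, not the library’s internal code):
import torch
import torch.nn.functional as F

d_ap = torch.tensor([2.0, 4.0])  # hardest anchor-positive distances
d_an = torch.tensor([5.0, 4.5])  # hardest anchor-negative distances

hard_margin_loss = torch.relu(d_ap - d_an + 5.0).mean()  # fixed margin of 5, as used below
soft_margin_loss = F.softplus(d_ap - d_an).mean()        # log(1 + exp(d_ap - d_an)), no margin to tune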
With that out of the way, let’s initialize our optimizer and criterion:
def train(cfg: TrainingConfig) -> None:
...
# init optimizer
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
],
"weight_decay": 0.01,
},
{
"params": [
p for n, p in param_optimizer if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate)
# init triplet loss
criterion = BatchHardTripletLoss(
distance=EuclideanDistance(squared=False), margin=5
)
Finally, the training loop. It should look familiar to PyTorch developers; the only new detail is the utility function supertriplets.utils.move_tensors_to_device(), used to move our tensors from the CPU to the selected device:
def train(cfg: TrainingConfig) -> None:
...
# init metrics tracking
train_loss_progress = []
valid_accuracy_progress = [valid_start_accuracy]
# training loop
for epoch in range(1, cfg.num_epochs + 1):
model.train()
for batch in tqdm(trainloader, total=len(trainloader), desc=f"Epoch {epoch}"):
data = batch["samples"]
labels = move_tensors_to_device(obj=data.pop("label"), device=device)
inputs = move_tensors_to_device(obj=data, device=device)
optimizer.zero_grad()
embeddings = model(**inputs)
loss = criterion(embeddings=embeddings, labels=labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()
train_loss_progress.append(loss.item())
# epoch valid metrics
model.eval()
valid_triplet_embeddings = get_triplet_embeddings(
dataloader=validloader, model=model, device=device
)
valid_accuracy = triplet_embeddings_evaluator.evaluate(
embeddings_anchors=valid_triplet_embeddings["anchors"],
embeddings_positives=valid_triplet_embeddings["positives"],
embeddings_negatives=valid_triplet_embeddings["negatives"],
)
valid_accuracy_progress.append(valid_accuracy)
test_triplet_embeddings = get_triplet_embeddings(
dataloader=testloader, model=model, device=device
)
test_final_accuracy = triplet_embeddings_evaluator.evaluate(
embeddings_anchors=test_triplet_embeddings["anchors"],
embeddings_positives=test_triplet_embeddings["positives"],
embeddings_negatives=test_triplet_embeddings["negatives"],
)
To finish, we save the model, evaluation plots, and metrics:
def train(cfg: TrainingConfig) -> None:
...
training_stats = {
"train_batch_triplet_loss": train_loss_progress,
"valid_epoch_accuracy_cosine": [
acc["accuracy_cosine"] for acc in valid_accuracy_progress
],
"valid_epoch_accuracy_euclidean": [
acc["accuracy_euclidean"] for acc in valid_accuracy_progress
],
"valid_epoch_accuracy_manhattan": [
acc["accuracy_manhattan"] for acc in valid_accuracy_progress
],
"test_final_accuracy_cosine": test_final_accuracy["accuracy_cosine"],
"test_final_accuracy_euclidean": test_final_accuracy["accuracy_euclidean"],
"test_final_accuracy_manhattan": test_final_accuracy["accuracy_manhattan"],
}
# save training config
json.dump(asdict(cfg), open(os.path.join(save_path, "training-config.json"), "w"))
# save model weights
torch.save(model.state_dict(), os.path.join(save_path, "model_state_dict.pt"))
# plot results
plot_results(
training_stats=training_stats,
save_path=os.path.join(save_path, "training-stats.png"),
)
# serialize training metrics into file
json.dump(training_stats, open(os.path.join(save_path, "training-stats.json"), "w"))
The result should look like this:
Encoding and indexing movie embeddings
The indexing script is located in the movie_retrieval.index
module.
First of all, add the TRAINED_MODEL_STATE_DICT_PATH variable to the .env file. It should point to the trained model weights. Here’s how the .env file should look so far:
TINY_MMIMDB_DATASET_PATH=<SAVED DATASET PATH HERE> # e.g. /home/gabriel/datasets/tinymmimdb
TRAINED_MODEL_STATE_DICT_PATH=<SAVED STATE DICT PATH HERE> # e.g. artifacts/2024-09-16-10-09-07-877894/model_state_dict.pt
The script itself is simple; we need to:
- Load the model;
- Encode the movies with it;
- Apply L2 normalization to the embeddings;
- Build a searchable index from the embeddings;
- Save to disk.
Here’s how it looks in code (check the movie_retrieval.utils module for more info):
import faiss
import numpy as np
import pandas as pd
import torch
from faiss import write_index
from supertriplets.sample import TextImageSample
from movie_retrieval.config import (
TINY_MMIMDB_DATASET_PATH,
TRAINED_MODEL_STATE_DICT_PATH,
)
from movie_retrieval.utils import (
get_embeddings_from_multimodal_model,
load_trained_multimodal_model,
prepare_tinymmimdb,
)
if __name__ == "__main__":
# load mmimdb
df = prepare_tinymmimdb(dataset_path=TINY_MMIMDB_DATASET_PATH)
# might take too long on the cpu, make sure you have a gpu available
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# load model
model = load_trained_multimodal_model(
state_dict_path=TRAINED_MODEL_STATE_DICT_PATH, device=device
)
# run model
embeddings = get_embeddings_from_multimodal_model(
examples=[
TextImageSample(text=text, image_path=image_path, label=-1)
for text, image_path in zip(df["plot_outline"], df["image_path"])
],
model=model,
device=device,
)
# creating FAISS index
embed_dim = embeddings[0].size
db_vectors = embeddings.copy().astype(np.float32)
db_ids = df["movie_id"].values.astype(np.int64)
faiss.normalize_L2(db_vectors)
index = faiss.IndexFlatIP(embed_dim)
index = faiss.IndexIDMap(index) # mapping 'movie_id' as id
index.add_with_ids(db_vectors, db_ids)
# saving
write_index(index, "movie-retrieval.index")
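Once the index is written, a quick sanity check can be appended to the end of the same script (so db_vectors and db_ids are still in scope): querying with one of the normalized database vectors should return that very movie as the top hit, with a cosine similarity close to 1.0.
# sanity check: the first database vector should retrieve itself first
query = db_vectors[:1].copy()
scores, retrieved_ids = index.search(query, k=3)
print(retrieved_ids[0], scores[0])  # retrieved_ids[0][0] should equal db_ids[0]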
Building a basic retrieval interface for the system
The application script is located in the app.py
file. It essentially does the following whenever someone clicks on the search button:
- Encode the movie poster and plot outline pair;
- Apply L2 normalization to the embeddings;
- Search for the most similar movies within the index;
- Return the results.
Our application code is the following:
import os
import tempfile
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from faiss import read_index
from supertriplets.sample import TextImageSample
from movie_retrieval.config import (
TINY_MMIMDB_DATASET_PATH,
TRAINED_MODEL_STATE_DICT_PATH,
)
from movie_retrieval.utils import (
get_embeddings_from_multimodal_model,
load_trained_multimodal_model,
prepare_tinymmimdb,
)
# load mmimdb
df = prepare_tinymmimdb(dataset_path=TINY_MMIMDB_DATASET_PATH)
df.set_index("movie_id", inplace=True)
# prefer gpu
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# load model
model = load_trained_multimodal_model(
state_dict_path=TRAINED_MODEL_STATE_DICT_PATH, device=device
)
# load movie retrieval index
index = read_index("movie-retrieval.index")
# streamlit stuff
st.title("MultiModal Movie Retrieval :movie_camera:")
st.text("https://github.com/gabrieltardochi/multimodal-movie-retrieval-pytorch")
st.subheader("", divider="rainbow")
# input fields for the poster image and plot outline
poster_image = st.file_uploader("Upload poster image (png/jpeg)")
plot_outline = st.text_area("Enter movie plot outline", max_chars=500, height=250)
if st.button("Search"):
# encode (poster image, plot outline) pair
with tempfile.TemporaryDirectory() as tmpdirname:
poster_image_bytes = poster_image.getvalue()
poster_path = os.path.join(tmpdirname, "img.jpg")
with open(poster_path, "wb") as file:
file.write(poster_image_bytes)
# encode input
embeddings = get_embeddings_from_multimodal_model(
examples=[
TextImageSample(text=plot_outline, image_path=poster_path, label=-1)
],
model=model,
device=device,
)
# searching with input embeddings
find_k_most_similar = 3
faiss.normalize_L2(embeddings)
similarities, similarities_ids = index.search(embeddings, k=find_k_most_similar)
similarities = similarities[0] # we are only searching with a single datapoint
similarities_ids = similarities_ids[
0
] # we are only searching with a single datapoint
similarities = np.around(np.clip(similarities, 0, 1), decimals=4)
df.loc[similarities_ids, "similarity_score"] = similarities
# build output
results = {}
for k in range(find_k_most_similar):
results[f"top{k+1}-similar-retrieved"] = df.loc[similarities_ids[k]].to_dict()
# return
st.success("Retrieved successfully.")
st.json(
results,
expanded=2,
)
We can now run $ streamlit run app.py
to see our app. It should look like the following:
We are now ready to use our system. Let’s see its response to a search for Spider-Man: No Way Home:
After clicking on “Search”, the results are displayed:
Conclusion
Wrapping up, the combination of multimodal representation learning with online hard triplet mining showcases just how effective deep learning can be in capturing complex, cross-modal relationships.
By leveraging SuperTriplets to fine-tune CLIP with supervised contrastive learning, we created a vector space where text and image data come together in a way that reveals subtle yet crucial distinctions between categories. This approach isn’t just about improved accuracy; it’s about building systems that genuinely “get” the context behind each piece of data, which is pivotal in nuanced applications.
And that’s it! Hopefully, I’ve helped you somehow in your deep representation learning journey. I truly believe that this field will play a big role in how we can reach the most intelligent systems we will ever see.