TensorFlow Recommenders 4 - Multi-task learning: How to train retrieval and ranking models together?



In the previous tutorials, we looked at the retrieval model, the ranking model, and how to leverage contextual features in both. In this post, we look at how to train the two models (retrieval and ranking) together as a multi-task learning problem, optimizing a single model with a combined loss.

Content

  • Prepare dataset
  • User model
  • Movie model
  • Movielens model
  • Train and evaluate the model

Prepare dataset

First, let's import the packages we need and print the TensorFlow and TFRS versions for reference.

from typing import Dict, Text  # for type hints

import pprint
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs

print(tf.__version__)
print(tfrs.__version__)
Output:

2.9.1
v0.7.0

Load the Movielens 100k dataset.

# Load the movielens dataset
ratings = tfds.load('movielens/100k-ratings', split='train')
ratings = ratings.map(lambda x: {
    'movie_title': x['movie_title'],
    'user_id': x['user_id'],
    'user_rating': x['user_rating'],
    'timestamp': x['timestamp']
})

movies = tfds.load('movielens/100k-movies', split='train')
movies = movies.map(lambda x: x['movie_title'])

timestamps = np.concatenate(list(
    ratings.map(lambda x: x['timestamp']).batch(100)))
max_timestamp = timestamps.max()
min_timestamp = timestamps.min()

timestamp_buckets = np.linspace(
    min_timestamp, max_timestamp, num=1000)

unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(
    lambda x: x['user_id']))))

print(len(unique_movie_titles), len(unique_user_ids))
Output:

1664 943
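The timestamp_buckets array holds 1,000 evenly spaced boundaries, which the Discretization layer in the user model below turns into bucket indices. As a rough illustration, NumPy's np.digitize should follow the same convention (left boundary inclusive, right boundary exclusive), which yields len(boundaries) + 1 possible indices; that is why the timestamp embedding table has len(timestamp_buckets)+1 rows. The toy range and timestamps here are made up, and this is a NumPy sketch of the bucketing idea, not the Keras layer itself.

```python
import numpy as np

# Hypothetical toy range standing in for min_timestamp / max_timestamp
min_ts, max_ts = 0, 100
# 5 evenly spaced boundaries: [0, 25, 50, 75, 100]
boundaries = np.linspace(min_ts, max_ts, num=5)

# Bucket index per value: values below the first boundary get 0,
# values at or above the last boundary get len(boundaries)
toy_timestamps = np.array([-1, 0, 30, 99, 100, 150])
bucket_ids = np.digitize(toy_timestamps, boundaries)
print(bucket_ids)  # [0 1 2 4 5 5]

# Number of distinct buckets an embedding table must cover
num_buckets = len(boundaries) + 1
print(num_buckets)  # 6
```

Because min_timestamp and max_timestamp are themselves boundaries, no real timestamp falls into bucket 0, but keeping the extra row is harmless.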

User Model

We reuse the UserModel defined in the previous tutorial, which can optionally include timestamps as part of the user representation.

class UserModel(tf.keras.Model):
    # User embedding: user id embedding, plus timestamp-bucket and normalized timestamp features when use_timestamps=True

    def __init__(self, use_timestamps):
        super().__init__()
        
        self._use_timestamps = use_timestamps
        
        # User id embedding
        self.user_embedding = tf.keras.Sequential([
            tf.keras.layers.StringLookup(
                vocabulary=unique_user_ids,
                mask_token=None),
            tf.keras.layers.Embedding(len(unique_user_ids)+1, 32)
        ])
        
        if use_timestamps:
            # Use timestamp
            self.timestamp_embedding = tf.keras.Sequential([
                tf.keras.layers.Discretization(timestamp_buckets.tolist()),
                tf.keras.layers.Embedding(len(timestamp_buckets)+1, 32)
            ])

            # Normalized timestamp
            self.normalized_timestamp = tf.keras.layers.Normalization(axis=None)
            self.normalized_timestamp.adapt(timestamps)
        
    
    def call(self, inputs):
        if not self._use_timestamps:
            return self.user_embedding(inputs['user_id'])
        
        return tf.concat([
            self.user_embedding(inputs['user_id']),
            self.timestamp_embedding(inputs['timestamp']),
            tf.reshape(self.normalized_timestamp(inputs['timestamp']),(-1,1))
        ], axis=1) 
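With use_timestamps=True, call() concatenates three pieces along axis 1: the 32-d user-id embedding, the 32-d timestamp-bucket embedding, and the normalized timestamp reshaped into a single column. Each user vector should therefore have 32 + 32 + 1 = 65 dimensions before the Dense(32) projection in MovielensModel. The NumPy sketch below mimics only the shapes, using random arrays as stand-ins for the layer outputs; the batch size of 4 is arbitrary.

```python
import numpy as np

batch_size, emb_dim = 4, 32  # arbitrary batch; 32 matches the Embedding layers
rng = np.random.default_rng(0)

# Stand-ins for the three outputs combined in UserModel.call()
user_id_emb = rng.normal(size=(batch_size, emb_dim))
timestamp_bucket_emb = rng.normal(size=(batch_size, emb_dim))
normalized_ts = rng.normal(size=(batch_size,))  # a (batch,) input stays 1-D after Normalization(axis=None)

# Reshape the scalar feature to (batch, 1) so it can be concatenated on axis 1
user_vector = np.concatenate(
    [user_id_emb, timestamp_bucket_emb, normalized_ts.reshape(-1, 1)], axis=1)
print(user_vector.shape)  # (4, 65)
```

Without the reshape, concatenating a 1-D array with 2-D arrays on axis 1 fails, which is why the original call() wraps the normalized timestamp in tf.reshape(..., (-1, 1)).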

Movie Model

Similarly, we reuse the MovieModel from the previous tutorial, which can optionally include the title text as part of the movie embedding.

class MovieModel(tf.keras.Model):
    # Movie embedding: title text + id 
    
    def __init__(self, use_title_text):
        super().__init__()
        max_tokens = 10_000
        
        self._use_title_text = use_title_text
        
        self.title_embedding = tf.keras.Sequential([
            tf.keras.layers.StringLookup(
                vocabulary=unique_movie_titles, mask_token=None),
            tf.keras.layers.Embedding(len(unique_movie_titles)+1, 32)
        ])
        
        if use_title_text:
            self.title_vectorizer = tf.keras.layers.TextVectorization(
                max_tokens=max_tokens)
            self.title_vectorizer.adapt(movies)

            self.title_text_embedding = tf.keras.Sequential([
                self.title_vectorizer,
                tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
                tf.keras.layers.GlobalAveragePooling1D()
            ])
        
        
    def call(self, inputs):
        if not self._use_title_text:
            return self.title_embedding(inputs)
        
        return tf.concat([
            self.title_embedding(inputs),
            self.title_text_embedding(inputs)
        ], axis=1)
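When use_title_text=True, the title is tokenized by TextVectorization, embedded with mask_zero=True, and averaged by GlobalAveragePooling1D. Since index 0 is the padding token, the mask should make the pooling average only over real tokens rather than the zero padding. Below is a NumPy sketch of that masked mean, with made-up token ids and a tiny 4-dimensional embedding table (both assumptions for illustration).

```python
import numpy as np

vocab_size, emb_dim = 10, 4
rng = np.random.default_rng(42)
embedding_table = rng.normal(size=(vocab_size, emb_dim))

# Two hypothetical tokenized titles, right-padded with 0 (the padding id)
token_ids = np.array([
    [3, 7, 0, 0],   # 2 real tokens
    [5, 2, 9, 0],   # 3 real tokens
])

token_embs = embedding_table[token_ids]            # (2, 4, emb_dim)
mask = (token_ids != 0).astype(float)[..., None]   # (2, 4, 1), 0 where padded

# Masked mean: sum the embeddings of real tokens, divide by the number of real tokens
pooled = (token_embs * mask).sum(axis=1) / mask.sum(axis=1)
print(pooled.shape)  # (2, 4)
```

In MovieModel both parts are 32-d, so concatenating the title-id embedding with the pooled title-text embedding gives a 64-d movie vector before the Dense(32) projection.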

Movielens Model

So far, the UserModel and MovieModel are exactly the same as in the previous tutorial. Now we move on to the new part: a MovielensModel that trains the retrieval and ranking tasks together in a multi-task training scheme.

Two tasks are defined in the __init__() method: a Ranking task (mean squared error loss, RMSE metric) and a Retrieval task (FactorizedTopK metrics over all movies). In compute_loss(), the total loss is a weighted sum of the two task losses, with weights self.rating_weight and self.retrieval_weight both set to 0.5, so the two tasks are weighted equally.
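The combined objective is total = rating_weight * rating_loss + retrieval_weight * retrieval_loss. The plain-Python sketch below uses made-up loss values to show how the weights trade the two tasks off; setting one weight to 0 reduces the model to a single-task ranking or retrieval model. One caveat, stated as an assumption worth checking against your TFRS version: the default Retrieval loss is summed over the batch, while the MeanSquaredError rating loss is averaged, so equal weights do not necessarily mean equal influence on the gradients.

```python
def combined_loss(rating_loss, retrieval_loss, rating_weight=0.5, retrieval_weight=0.5):
    """Weighted sum of the two task losses, mirroring MovielensModel.compute_loss()."""
    return rating_weight * rating_loss + retrieval_weight * retrieval_loss

# Hypothetical per-batch loss values, chosen only for illustration
rating_loss, retrieval_loss = 1.0, 300.0

print(combined_loss(rating_loss, retrieval_loss))            # 150.5 (equal weights)
print(combined_loss(rating_loss, retrieval_loss, 1.0, 0.0))  # 1.0 (ranking only)
print(combined_loss(rating_loss, retrieval_loss, 0.0, 1.0))  # 300.0 (retrieval only)
```

Since the weights are plain attributes, they are the natural hyper-parameters to tune, as the comment in compute_loss() suggests.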

class MovielensModel(tfrs.models.Model):
    
    def __init__(self, use_timestamps=True, use_title_text=True):
        super().__init__()
        
        self.rating_weight = 0.5
        self.retrieval_weight = 0.5
        
        # User and Movie models
        self.user_model = tf.keras.Sequential([
            UserModel(use_timestamps),
            tf.keras.layers.Dense(32)
        ])
        self.movie_model = tf.keras.Sequential([
            MovieModel(use_title_text),
            tf.keras.layers.Dense(32)
        ])
        
        # Ranking model
        self.rating_model = tf.keras.Sequential([
            # Multiple dense layers
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            # Prediction layer
            tf.keras.layers.Dense(1)
        ])
    
        # Multi-tasks
        self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
            loss=tf.keras.losses.MeanSquaredError(),
            metrics=[tf.keras.metrics.RootMeanSquaredError()]
        )
        self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=movies.batch(128).map(self.movie_model)
            )
        )
            
    def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:
        user_embeddings = self.user_model({
            'user_id': features['user_id'],
            'timestamp': features['timestamp']
        })
        movie_embeddings = self.movie_model(
            features['movie_title']
        )
        return (
            user_embeddings, 
            movie_embeddings,
            self.rating_model(tf.concat([
                user_embeddings,
                movie_embeddings
            ], axis=1))
        )
        
    def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
        user_embeddings, movie_embeddings, rating_predictions = self.call(features)
        # Retrieval loss
        retrieval_loss = self.retrieval_task(user_embeddings, movie_embeddings)
        # Rating loss
        rating_loss = self.rating_task(
            labels=features['user_rating'],
            predictions=rating_predictions
        )
        
        # Combine two losses with hyper-parameters (to be tuned)
        return (self.rating_weight * rating_loss
                + self.retrieval_weight * retrieval_loss)
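The retrieval task scores every user embedding in a batch against every movie embedding in that same batch, treating the diagonal (each user's actually rated movie) as the positive and the other movies in the batch as negatives. The NumPy sketch below reproduces that idea with random embeddings. It assumes the default tfrs.tasks.Retrieval loss is a from-logits categorical cross-entropy summed over the batch (worth confirming in the TFRS source for your version), and it ignores options such as sampling-probability correction.

```python
import numpy as np

def in_batch_softmax_loss(query_embs, candidate_embs):
    """Sum over the batch of softmax cross-entropy, with positives on the diagonal."""
    # (batch, batch) dot-product scores: row i = user i vs every movie in the batch
    scores = query_embs @ candidate_embs.T
    # Numerically stable log-softmax over each row
    scores = scores - scores.max(axis=1, keepdims=True)
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    # The label for row i is column i (the movie user i actually rated)
    return -np.trace(log_probs)

rng = np.random.default_rng(7)
batch_size, dim = 8, 32
users = rng.normal(size=(batch_size, dim))
movies_in_batch = rng.normal(size=(batch_size, dim))
print(in_batch_softmax_loss(users, movies_in_batch))

# All-zero embeddings give uniform rows, so the loss is batch_size * log(batch_size)
zeros = np.zeros((batch_size, dim))
print(in_batch_softmax_loss(zeros, zeros), batch_size * np.log(batch_size))
```

A summed loss like this grows with the batch size, which is likely part of why the loss values in the evaluation output later in this post are in the thousands while the RMSE is around 1.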





Train and evaluate the model

We use 80% of the dataset for training, and the rest (20%) for testing.

# -------------------------------
# Experiment
# -------------------------------
# Prepare data
tf.random.set_seed(7)
shuffled = ratings.shuffle(100_000, seed=7,
                reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

cached_train = train.shuffle(100_000).batch(2048).cache()
cached_test = test.batch(4096).cache()

model = MovielensModel(use_timestamps=True, use_title_text=True)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_acc = model.evaluate(
    cached_train, return_dict=True)['factorized_top_k/top_100_categorical_accuracy']
test_acc = model.evaluate(
    cached_test, return_dict=True)['factorized_top_k/top_100_categorical_accuracy']

print(f'Top-100 accuracy (train): {train_acc:.2f}')
print(f'Top-100 accuracy (test): {test_acc:.2f}')
Output:

Top-100 accuracy (train): 0.34
Top-100 accuracy (test): 0.25
We can also look at the evaluation results for all metrics on the test set.

model.evaluate(cached_test, return_dict=True)
Output:

{'root_mean_squared_error': 1.0596544742584229,
 'factorized_top_k/top_1_categorical_accuracy': 0.00139999995008111,
 'factorized_top_k/top_5_categorical_accuracy': 0.011549999937415123,
 'factorized_top_k/top_10_categorical_accuracy': 0.025450000539422035,
 'factorized_top_k/top_50_categorical_accuracy': 0.1335500031709671,
 'factorized_top_k/top_100_categorical_accuracy': 0.24815000593662262,
 'loss': 13961.8505859375,
 'regularization_loss': 0,
 'total_loss': 13961.8505859375}
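The factorized_top_k/top_K_categorical_accuracy values come from FactorizedTopK. Roughly, for each test rating, the user embedding is scored against all candidate movies (here all 1,664 titles), and the rating counts as a hit if the movie the user actually rated is among the K highest-scoring candidates. A top-100 accuracy of about 0.25 therefore means the true movie landed in the top 100 for roughly a quarter of the test ratings. Below is a NumPy sketch of that computation with invented scores and labels.

```python
import numpy as np

def top_k_accuracy(scores, true_ids, k):
    """Fraction of rows whose true candidate is among the k highest scores."""
    # Indices of the k best-scoring candidates in each row (order within top-k is irrelevant)
    top_k = np.argsort(-scores, axis=1)[:, :k]
    hits = [true_id in row for true_id, row in zip(true_ids, top_k)]
    return float(np.mean(hits))

# 3 hypothetical users scored against 5 candidate movies
scores = np.array([
    [0.9, 0.1, 0.3, 0.2, 0.0],   # true movie 0 ranks 1st
    [0.2, 0.8, 0.7, 0.1, 0.5],   # true movie 2 ranks 2nd
    [0.6, 0.5, 0.4, 0.3, 0.9],   # true movie 3 ranks 5th (last)
])
true_ids = np.array([0, 2, 3])

for k in (1, 2, 5):
    print(k, top_k_accuracy(scores, true_ids, k))
```

Accuracy can only grow as K grows, which matches the increasing top-1 to top-100 values above.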

Next, we take the first 5 movies in the test set, predict the rating user 42 would give each of them at a fixed timestamp (892839492), and sort them by predicted rating in descending order.

test_ratings = {}
# Score the first 5 test movies for user 42 at a fixed timestamp;
# the model returns (user_embeddings, movie_embeddings, rating_predictions)
for m in test.take(5):
    title = m['movie_title'].numpy()
    _, _, test_ratings[title] = model({
        'user_id': np.array(['42']),
        'timestamp': np.array([892839492]),
        'movie_title': np.array([title])
    })

# Sort titles by predicted rating, highest first
for m in sorted(test_ratings, key=test_ratings.get, reverse=True):
    print(m)
Output:

b'Chasing Amy (1997)'
b'Top Gun (1986)'
b'Twister (1996)'
b'Event Horizon (1997)'
b'Batman Forever (1995)'

We can also look at the predicted rating for each of those movies in test_ratings:

for r in test_ratings: 
    print(r, test_ratings[r].numpy()[0][0])
Output:

b'Top Gun (1986)' 3.278518
b'Chasing Amy (1997)' 3.3716574
b'Batman Forever (1995)' 2.8443925
b'Twister (1996)' 3.1891446
b'Event Horizon (1997)' 3.0244293
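The sorted list printed earlier is consistent with these scores. Because each value in test_ratings is a (1, 1) tensor, the earlier sort compares tensors directly; extracting a Python float first, e.g. float(test_ratings[title].numpy()[0][0]), is a more explicit alternative. The sketch below copies the scores printed above as plain floats (not recomputed) and reproduces the descending order.

```python
# Predicted ratings for user 42 at timestamp 892839492, copied from the output above
scores = {
    b'Top Gun (1986)': 3.278518,
    b'Chasing Amy (1997)': 3.3716574,
    b'Batman Forever (1995)': 2.8443925,
    b'Twister (1996)': 3.1891446,
    b'Event Horizon (1997)': 3.0244293,
}

# Sort titles by predicted rating, highest first
ranking = sorted(scores, key=scores.get, reverse=True)
for title in ranking:
    print(title.decode(), round(scores[title], 2))
```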

Finally, let's change the timestamp for user 42 to check whether the predicted ratings for those movies change. When we change the timestamp from 892839492 to 879024327, the predicted ratings shift slightly:
Output:

b'Top Gun (1986)' 3.3720827
b'Chasing Amy (1997)' 3.408927
b'Batman Forever (1995)' 3.0349593
b'Twister (1996)' 3.162705
b'Event Horizon (1997)' 3.058013
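To quantify how much the timestamp matters here, we can diff the two sets of printed predictions. The plain-Python sketch below copies both outputs; with these numbers, every shift is under 0.2, Batman Forever (1995) moves the most, and the descending order of the five movies stays the same.

```python
# Predictions for user 42 at the two timestamps, copied from the outputs above
before = {  # timestamp 892839492
    'Top Gun (1986)': 3.278518,
    'Chasing Amy (1997)': 3.3716574,
    'Batman Forever (1995)': 2.8443925,
    'Twister (1996)': 3.1891446,
    'Event Horizon (1997)': 3.0244293,
}
after = {  # timestamp 879024327
    'Top Gun (1986)': 3.3720827,
    'Chasing Amy (1997)': 3.408927,
    'Batman Forever (1995)': 3.0349593,
    'Twister (1996)': 3.162705,
    'Event Horizon (1997)': 3.058013,
}

# Per-movie change in predicted rating, largest absolute change first
deltas = {title: after[title] - before[title] for title in before}
for title, delta in sorted(deltas.items(), key=lambda kv: abs(kv[1]), reverse=True):
    print(f'{title}: {delta:+.3f}')

# Does the ranking change between the two timestamps?
rank_before = sorted(before, key=before.get, reverse=True)
rank_after = sorted(after, key=after.get, reverse=True)
print('Same order:', rank_before == rank_after)
```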

More TFRS tutorials can be found at https://parklize.blogspot.com/p/tensorflow.html
