Chapter 1. Introduction

Recommendation systems are integral to the development of the internet we know today, and they are a central function of many emerging technology companies. Beyond the search ranking that opened the web's breadth to everyone, the new and exciting movies all your friends are watching, and the most relevant ads that companies pay top dollar to show you, lie more applications of recommendation systems every year. The addictive For You page from TikTok, Spotify's Discover Weekly, board suggestions on Pinterest, and Apple's App Store are all powered by the recommendation systems of today. These days, sequential transformer models, multimodal representations, and graph neural networks are among the brightest areas of R&D in machine learning, and all of them are being put to use in recommendation systems.

Ubiquity of any technology often prompts questions of how the technology works, why it has become so common, and whether one can get in on the action. For recommendation systems, the how is quite complicated. We'll need to understand the geometry of taste, and how only a little bit of interaction from a user can provide us with a GPS signal in that abstract space. You'll see how to quickly gather a great set of candidates, and how to refine them into a cohesive set of recommendations. Finally, you'll learn how to evaluate your recommender, build the endpoint that serves inference, and log its behavior.

We will formulate a number of variants of the core problem of recommendation systems, but ultimately the motivating problem framing is: given a collection of things that may be recommended, choose an ordered few for the current context and user that best match according to some objective.

Key components of a RecSys

As we increase complexity and sophistication, let's keep in mind what the components of our system are. We will use what are called string diagrams to keep track of the different components, though the literature presents these diagrams in a variety of ways.

We will identify and build on three core components of recommendation systems:

Collector

The collector’s role is to know what is in the collection of things that may be recommended, and the necessary features or attributes of those things. Note that this collection often is a subset based on context or state.

Ranker

The ranker’s role is to take the collection provided by the collector, and order some or all of them, according to a model for the context and user.

Server

The server’s role is to take the ordered subset provided by the ranker, ensure that the necessary data schema is satisfied–including essential business logic–and return the requested number of recommendations.
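
Before making these roles concrete with an example, here is one minimal sketch of how the three components might compose in Python. All of the names here are illustrative placeholders chosen for this sketch, not part of any established library:

from typing import List, Protocol

class Collector(Protocol):
    def collect(self, context: dict) -> List[str]:
        # Gather the recommendable items and their attributes for this context.
        ...

class Ranker(Protocol):
    def rank(self, items: List[str], context: dict) -> List[str]:
        # Order the collected items according to a model of the context and user.
        ...

class Server(Protocol):
    def serve(self, ranked: List[str], num_recs: int) -> List[str]:
        # Apply schema and business logic, then return the requested number of recs.
        ...

def recommend(
    collector: Collector, ranker: Ranker, server: Server,
    context: dict, num_recs: int,
) -> List[str]:
    # The pipeline composes collector -> ranker -> server.
    items = collector.collect(context)
    ranked = ranker.rank(items, context)
    return server.serve(ranked, num_recs)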

Take, for example, a hospitality scenario with a waiter:

When you sit down at your table, you look at the menu unsure of what you should order. You ask the wait staff, “what do you think I should order for dessert?”

The waiter checks their notes and says “we’re out of the key lime pie, but people really like our banana creme pie. If you like pomegranate, we make pom ice-cream from scratch; and it’s hard to go wrong with the donut a la mode–it’s our most popular dessert.”

In this short exchange, the waiter first serves as a collector; they identify the desserts on the menu, accommodate current inventory conditions by checking their notes, and finally prepare themselves to talk about the characteristics of the desserts.

Next, they serve as a ranker; they mention items both high scoring in popularity (banana creme pie and donut a la mode), and a contextually high match item based on the patron’s features (if they like pomegranate).

Finally, they serve the recommendations verbally, including both explanatory features of their algorithm and multiple choices.

While this seems a bit cartoonish, remember to ground discussions of recommender systems in real-world applications. One of the advantages of working in recommendation systems is that inspiration is always nearby.

Simplest possible recommenders

We’ve established the components of a recommender, but to really make this practical, we’ll need to see this in action. While much of the book will be dedicated to practical recommender systems, first we’ll start with a toy and scaffold from there.

The trivial recommender

The absolute simplest recommender is not very interesting, but it can still be demonstrated within our framework. It's called the trivial recommender (TR) because it contains virtually no logic:

import random
from typing import List, Optional

def get_trivial_recs() -> Optional[List[str]]:
    # Pick a random item; MAX_ITEM_INDEX and get_availability are
    # deliberately left out of scope (see the note below).
    item_id = random.randint(0, MAX_ITEM_INDEX)

    if get_availability(item_id):
        return [item_id]
    return None

Notice that this recommender may return either a list containing a single item_id or None. Also observe that this recommender takes no arguments, and that MAX_ITEM_INDEX and get_availability reference names out of scope. Software principles ignored, let's think about the three components:

  • Collector: A random item_id is generated. The TR collects by checking the availability of item_id. One could argue that having access to item_id is also part of the collector’s responsibility. Conditional upon the availability, the collection of recommendable things is either [item_id] or None (recall that None is a collection in the set-theoretic sense).

  • Ranker: The TR ranks with a no-op; i.e., the ranking of one or zero objects in a collection is the identity function on that collection, so we merely do nothing and move on to the next step.

  • Server: The TR serves recommendations via its return statements. The only schema specified above is the return type: Optional[List[str]].

The above recommender, while neither interesting nor useful, provides a skeleton that we will add to as we develop further.

Most-popular-item recommender

The most-popular-item recommender (MPIR) is the simplest recommender that offers real utility. You're unlikely to build applications around the MPIR alone, but it's useful in tandem with other components, and it provides a basis for further development.

It works just as it says: it returns the most popular items.

from typing import Dict, List, Optional

def get_item_popularities() -> Optional[Dict[str, int]]:
    ...
        # Dict of pairs: (item-identifier, count of times item was chosen)
        return item_choice_counts
    return None

def get_most_popular_recs(max_num_recs: int) -> Optional[List[str]]:
    items_popularity_dict = get_item_popularities()
    if items_popularity_dict:
        sorted_items = sorted(
            items_popularity_dict.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        return [i[0] for i in sorted_items][:max_num_recs]
    return None

Here we assume that get_item_popularities has knowledge of all available items and how many times they’ve been chosen.

This recommender attempts to return the max_num_recs most popular items that are available. While simple, this is a very useful recommender and a great place to start when building a recommendation system. Additionally, we will return to this example over and over, because other recommenders use this core and iteratively improve its internal components.

  • Collector: The MPIR first makes a call to get_item_popularities, which–via database or memory access–knows which items are available and how many times they've been selected. For convenience, we assume that they're returned as a dictionary, with keys given by the string that identifies the item, and values given by the number of times that item has been chosen. We implicitly assume that items not appearing in this dictionary are not available.

  • Ranker: Here we see our first simple ranker: ranking by sorting on values. Because the collector has organized our data such that the values of the dictionary are the counts, we use the Python built-in function sorted. Note that we use the key argument to indicate that we wish to sort by the second element of the tuples–in this case, equivalent to sorting by values–and we send the reverse flag to make our sort descending.

  • Server: Finally, we need to satisfy our API schema, which is again provided via the return type hint: Optional[List[str]]. This says the return should be a nullable list of the item-identifier strings that we're recommending, so we use a list comprehension to grab the first element of each tuple. But wait! Our function has this max_num_recs parameter–what might that be doing there? Of course, it is suggesting that our API schema expects no more than max_num_recs items in the response. We handle this via the slice operator, but note that our return contains between 0 and max_num_recs results, as the quick example below demonstrates.
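
To see the serving behavior concretely, here is a quick, hypothetical check of the MPIR with a hard-coded popularity dictionary standing in for a real get_item_popularities backed by a database:

def get_item_popularities() -> Optional[Dict[str, int]]:
    # Hypothetical stub: in practice this would be a database or cache lookup.
    return {"pie": 50, "ice-cream": 30, "donut": 80, "cake": 20}

print(get_most_popular_recs(max_num_recs=2))
['donut', 'pie']

If fewer than max_num_recs items are available, the slice simply returns them all, which is why the response holds between 0 and max_num_recs results.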

Consider the possibilities at your fingertips equipped with the MPIR; recommending customers' favorite item in each top-level category could make for a simple but useful first stab at recommendations for e-commerce. The most popular video of the day may make for a good home-page experience on your video site.

A Gentle Introduction to Jax

Since this book has JAX in the title, we will provide a very gentle introduction to JAX here. The official documentation for JAX can be found at Jax Documentation. JAX is a framework for writing mathematical code in Python that is just-in-time (JIT) compiled. JIT compilation allows the same code to run on CPUs, GPUs, and TPUs. This makes it easy to write performant code that takes advantage of the parallel processing power of vector processors. Additionally, one of the design philosophies of JAX is to support tensors and gradients as core concepts, making it an ideal tool for ML systems that utilize gradient-based learning on tensor-shaped data.

The easiest way to play with JAX is probably via Google Colab, which is a hosted Python notebook on the web.

Basic Types, Initialization and Immutability

Let's start by learning about JAX types. We will construct a small three-dimensional vector in JAX and point out some differences between JAX and NumPy.

import jax.numpy as jnp
import numpy as np

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

print(x)
[1. 2. 3.]

print(x.shape)
(3,)

print(x[0])
1.0

x[0] = 4.0
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable.

JAX's interface is mostly similar to that of NumPy. By convention, we import JAX's version of NumPy as jnp to distinguish it from NumPy's np, so that we know which version of a mathematical function we want to use. This is because we might sometimes want to run code on a vector processor like a GPU or TPU using JAX, while we might prefer to run other code on a CPU in NumPy.

The first thing to notice is that JAX arrays have types. The typical float type is float32, which uses 32 bits to represent a floating-point number. There are other types, such as float64, which has greater precision, and float16, a half-precision type that usually only runs on some GPUs.

The other thing to note is that JAX tensors have a shape. This is a tuple, so (3,) means a three-dimensional vector along the first axis. A matrix has two axes, and a tensor has three or more axes.
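
For instance, here is a small illustration of dtypes and shapes, using the jnp import from earlier (outputs follow each print):

x16 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float16)
print(x16.dtype)
float16

m = jnp.zeros([2, 3])     # a matrix: two axes
t = jnp.zeros([2, 3, 4])  # a rank-3 tensor: three axes
print(m.shape, t.shape)
(2, 3) (2, 3, 4)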

Now we come to places where JAX differs from NumPy. It is really important to pay attention to JAX - The Sharp Bits in the documentation to understand the differences between JAX and NumPy. JAX's philosophy is about speed and purity. By making functions pure (that is, without side effects) and by making data immutable, JAX is able to make some guarantees to the underlying accelerated linear algebra (XLA) library that it uses to talk to GPUs. The guarantee to XLA is that these functions applied to data can be run in parallel and have deterministic results without side effects, and thus XLA is able to compile these functions and make them run much faster than if they were run in NumPy alone.

You can see that modifying one element in x results in an error. JAX would prefer that the array x be replaced rather than modified. One way to modify elements in an array is to do it in NumPy rather than JAX, and convert NumPy arrays to JAX (for example, using jnp.array(np_array)) when the subsequent code needs to run fast on immutable data.
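
JAX also provides a functional update syntax through the .at property, which returns a modified copy and leaves the original array untouched. A small example using the x defined above:

y = x.at[0].set(4.0)

print(x)
[1. 2. 3.]

print(y)
[4. 2. 3.]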

Indexing and Slicing

Another important skill to learn is that of indexing and slicing arrays.

x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.int32)

# Print the whole matrix.
print(x)
[[1 2 3]
 [4 5 6]
 [7 8 9]]

# Print the first row.
print(x[0])
[1 2 3]


# Print the last row.
print(x[-1])
[7 8 9]

# Print the second column.
print(x[:, 1])
[2 5 8]

# Print every other element
print(x[::2, ::2])
[[1 3]
 [7 9]]

NumPy introduced indexing and slicing operations that allow us to access different parts of an array. In general, the notation follows a start:end:stride convention: the first element indicates where to start, the second where to end (exclusive), and the stride how many elements to step over. The syntax is very similar to that of the Python range function. Slicing allows us to access views of a tensor elegantly. Slicing and indexing are important skills to master, especially when we start to manipulate tensors in batches, which we are wont to do to make the most use of acceleration hardware.
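
As one small example of batch manipulation, slicing can carve a dataset into minibatches. This sketch assumes a toy data array of six examples:

data = jnp.arange(12).reshape(6, 2)  # six examples, two features each
batch_size = 2

for start in range(0, data.shape[0], batch_size):
    batch = data[start:start + batch_size]
    print(batch.shape)
(2, 2)
(2, 2)
(2, 2)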

Broadcasting

Broadcasting is another feature of NumPy and JAX to be aware of. When a binary operation such as addition or multiplication is applied to two tensors of different sizes, the tensor whose axes are of size 1 is lifted up in rank to match the larger-sized tensor. For example, if a tensor of shape (3, 3) is multiplied by a tensor of shape (3, 1), the rows of the second tensor are duplicated before the operation so that it looks like a tensor of shape (3, 3).

x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.int32)

# Scalar broadcasting.
y = 2 * x
print(y)
[[ 2  4  6]
 [ 8 10 12]
 [14 16 18]]

# Vector broadcasting. Axes with shape 1 are duplicated.
vec = jnp.reshape(jnp.array([0.5, 1.0, 2.0]), [3, 1])
y = vec * x
print(y)
[[ 0.5  1.   1.5]
 [ 4.   5.   6. ]
 [14.  16.  18. ]]

vec = jnp.reshape(vec, [1, 3])
y = vec * x
print(y)
[[ 0.5  2.   6. ]
 [ 2.   5.  12. ]
 [ 3.5  8.  18. ]]

The first case is the simplest: scalar multiplication. In this case, the scalar is multiplied throughout the matrix. In the second case, we have a vector of shape (3, 1) multiplying the matrix. The first row is multiplied by 0.5, the second row is multiplied by 1.0, and the third row is multiplied by 2.0. However, if the vector is reshaped to (1, 3), then the columns are multiplied by the successive entries of the vector instead.
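
Combining both cases, broadcasting a (3, 1) tensor against a (1, 3) tensor stretches each size-1 axis, producing a (3, 3) result; this is one way to compute an outer product:

col = jnp.reshape(jnp.array([1.0, 2.0, 3.0]), [3, 1])
row = jnp.reshape(jnp.array([10.0, 20.0, 30.0]), [1, 3])
print(col * row)
[[10. 20. 30.]
 [20. 40. 60.]
 [30. 60. 90.]]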

Random Numbers

Along with JAX's philosophy of pure functions comes its particular way of handling random numbers. Because pure functions do not cause side effects, a JAX random number generator cannot mutate a hidden random number seed the way other random number generators do. Instead, JAX deals with random number keys whose state is updated explicitly.

import jax.random as random

key = random.PRNGKey(0)
x = random.uniform(key, shape=[3, 3])
print(x)
[[0.35490513 0.60419905 0.4275843 ]
 [0.23061597 0.6735498  0.43953657]
 [0.25099766 0.27730572 0.7678207 ]]

key, subkey = random.split(key)
x = random.uniform(key, shape=[3, 3])
print(x)
[[0.0045197  0.5135027  0.8613342 ]
 [0.06939673 0.93825936 0.85599923]
 [0.706004   0.50679076 0.6072922 ]]

y = random.uniform(subkey, shape=[3, 3])
print(y)
[[0.34896135 0.48210478 0.02053976]
 [0.53161216 0.48158717 0.78698325]
 [0.07476437 0.04522789 0.3543167 ]]

JAX first requires you to create a random number key from a seed. This key is then passed into random number generation functions like uniform to create random numbers in the 0 to 1 range. To create more random numbers, however, JAX requires that you split the key into two parts: a new key that can be used to generate other keys, and a subkey that can be used to generate new random numbers. This allows JAX to deterministically and reliably reproduce random numbers even when many parallel operations are calling the random number generator. We just split a key into as many parts as we have parallel operations, and the resulting random numbers are both randomly distributed and reproducible. This is a nice property when you want to reproduce experiments reliably.
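
split can also produce many subkeys in one call, which is handy for fanning work out to parallel operations. A small sketch (outputs omitted, since the values depend on the seed):

key = random.PRNGKey(42)
subkeys = random.split(key, num=4)  # one independent key per parallel task
draws = [random.uniform(k, shape=[2]) for k in subkeys]
# Rerunning this block with the same seed reproduces exactly the same draws.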

Just In Time compilation

Where JAX starts to diverge from NumPy in terms of execution speed is when we start using just-in-time (JIT) compilation. JITing code, that is, transforming it so it is compiled just in time, allows the same code to run on CPUs, GPUs, or TPUs.

import jax

x = random.uniform(key, shape=[2048, 2048]) - 0.5

def my_function(x):
  x = x @ x
  return jnp.maximum(0.0, x)

%timeit my_function(x).block_until_ready()
302 ms ± 9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

my_function_jitted = jax.jit(my_function)

%timeit my_function_jitted(x).block_until_ready()
294 ms ± 5.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The JITed code is not much faster on a CPU but will be dramatically faster on a GPU or TPU backend. There is also some compilation overhead when the function is called for the first time, which can skew the timing of the first call. Functions that can be JITed have restrictions, such as mostly calling JAX operations internally and having restrictions on loop operations; variable-length loops trigger frequent recompilations. The documentation on JAX Jitting covers many of the nuances of getting functions to JIT compile.
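
One common workaround, shown here as a sketch, is to mark Python-level arguments as static with static_argnums, so that JAX traces and compiles the function once per distinct value of that argument:

from functools import partial

@partial(jax.jit, static_argnums=0)
def power(n, x):
    # n is static, so this Python loop is unrolled at trace time.
    for _ in range(n):
        x = x * x
    return x

y = power(3, x)  # compiles for n=3 on the first call, reuses the compiled code after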

Summary

While we’ve not done much math yet, we have gotten to the point where we may begin providing recommendations and implementing deeper logic into these components. We’ll start doing things that look like machine learning soon enough.

So far, we have defined what a recommendation problem is, we have set up the core architecture of our recommender system (the collector, the ranker, and the server), and we have shown a couple of trivial recommenders to illustrate how the pieces come together.

Next, we'll see the core relationship that recommendation systems seek to exploit: the user-item matrix. This matrix lets us build a model of personalization, which will lead to ranking.
