Sabrina J. Mielke

From PyTorch to JAX: towards neural net frameworks that purify stateful code

2020-03-09


Update 2021-07-01: I gave a talk at the Flax/JAX community week largely based on this blogpost—but made a bit more concise and punchy, and including an example of flax's new linen API! The talk is half Google Slides and half Google Colab notebook, and the recording is on YouTube :)


Note: this post also exists as the original Colab notebook from which it was rendered—if you prefer that sort of thing.

JAX, Google's now-over-a-year-old Python library for machine learning and other numerical computing, describes itself as “Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more”—and while that definition is certainly fitting, it is a bit intimidating. I would describe JAX as numpy, but on GPU, and then move on to the one feature we will be most concerned with in this post: its autodifferentiation capability, i.e., how to get gradients of some loss function your code computes with respect to your input parameters. If you haven't heard of JAX at all, I can recommend Skye Wanderman-Milne's talk at NeurIPS on JAX (or check out the corresponding slides). It's a cool framework with cool ideas!

That said, moving from PyTorch or Tensorflow 2 to JAX is a huge change: the way we build up computation and, more importantly, backpropagate through it is fundamentally different in the two! PyTorch builds up a graph as you compute the forward pass, and one call to backward() on some “result” node then augments each intermediate node in the graph with the gradient of the result node with respect to that intermediate node. JAX on the other hand makes you express your computation as a Python function, and by transforming it with grad() gives you a gradient function that you can evaluate like your computation function—but instead of the output it gives you the gradient of the output with respect to (by default) the first parameter that your function took as input:

PyTorch vs. JAX on a very simple 1D linear “layer”
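In code, the contrast sketched in that figure looks roughly like this (a minimal sketch of both sides; the JAX half is exactly the example we will run later in this post, and the PyTorch half is the standard tape-based equivalent):

import torch

# PyTorch: the forward pass builds up a graph; backward() then fills in .grad.
w = torch.tensor(13., requires_grad=True)
x = torch.tensor(42.)
y = w * x
y.backward()
print(w.grad)  # tensor(42.)

import jax

# JAX: express the computation as a pure function and transform it with grad().
def f(w, x):
    return w * x

df_dw = jax.grad(f)  # a new function: gradient of f wrt its first argument
print(df_dw(13., 42.))  # 42.0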

This has consequences for how you write code and build up models in both frameworks, of course. So when you're used to tape-based auto-differentiation and working with stateful objects in PyTorch or Tensorflow 2, coming to JAX may be quite a shock—and while running grad() on numpy-oneliners like the one above (which we will actually run later below) is cool and all, you wonder what a minimal example for, say, a language model would look like (language models aren't quite as straightforward to implement as ResNets: they can have dynamic structures that aren't always nicely divisible into “layers”).

Maybe you decided to look at libraries like flax, trax, or haiku and what you see at least in the ResNet examples looks not too dissimilar from any other framework: define some layers, run some trainers... but what is it that actually happens there? What's the route from these tiny numpy functions to training big hierarchical neural nets?

That's the niche this post is trying to fill. We will:

  1. quickly recap a stateful LSTM-LM implementation in a tape-based gradient framework, specifically PyTorch,
  2. see how PyTorch-style coding relies on mutating state, learn about mutation-free pure functions and build (pure) zappy one-liners in JAX,
  3. step-by-step go from individual parameters to medium-size modules by registering them as pytree nodes,
  4. combat growing pains by building fancy scaffolding and controlling context to extract initialized parameters and purify functions, and
  5. realize that we could get that easily in a framework like DeepMind's haiku using its transform mechanism.

Things we will not do in this tutorial: build state-of-the-art models or elegant, “idiomatic” codebases—we just try to build minimal examples that exhibit the complexity we are looking for. We also won't cover batching (which is easy with vmap()), distributing (pmap()), and XLA-compiling your code (jit()) in JAX—all of which are really cool features, but beside the point.

Let's get started!

1. An LSTM-LM in PyTorch

To make sure we're on the same page, let's implement the language model I want to work towards in PyTorch. To keep the comparison straightforward, we will implement things from scratch as much as possible in all three approaches. Let's start with an LSTMCell that holds some parameters:

import torch

class LSTMCell(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super(LSTMCell, self).__init__()
        self.weight_ih = torch.nn.Parameter(torch.rand(4*out_dim, in_dim))
        self.weight_hh = torch.nn.Parameter(torch.rand(4*out_dim, out_dim))
        self.bias = torch.nn.Parameter(torch.zeros(4*out_dim,))
        
    def forward(self, inputs, h, c):
        ifgo = self.weight_ih @ inputs + self.weight_hh @ h + self.bias
        i, f, g, o = torch.chunk(ifgo, 4)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        g = torch.tanh(g)
        o = torch.sigmoid(o)
        new_c = f * c + i * g
        new_h = o * torch.tanh(new_c)
        return (new_h, new_c)

Next, build a super simple 1-layer LSTM language model using this cell. To keep the example simple, we will just use a plain matrix for embeddings. This and the learned (h,c)₀ will demonstrate how individual parameters are registered in our solutions.

class LSTMLM(torch.nn.Module):
    def __init__(self, vocab_size, dim=17):
        super().__init__()
        self.cell = LSTMCell(dim, dim)
        self.embeddings = torch.nn.Parameter(torch.rand(vocab_size, dim))
        self.c_0 = torch.nn.Parameter(torch.zeros(dim))
    
    @property
    def hc_0(self):
        return (torch.tanh(self.c_0), self.c_0)

    def forward(self, seq, hc):
        loss = torch.tensor(0.)
        for idx in seq:
            loss -= torch.log_softmax(self.embeddings @ hc[0], dim=-1)[idx]
            hc = self.cell(self.embeddings[idx,:], *hc)
        return loss, hc
    
    def greedy_argmax(self, hc, length=6):
        with torch.no_grad():
            idxs = []
            for i in range(length):
                idx = torch.argmax(self.embeddings @ hc[0])
                idxs.append(idx.item())
                hc = self.cell(self.embeddings[idx,:], *hc)
        return idxs

To demonstrate that it works, let's train:

torch.manual_seed(0)

# As training data, we will have indices of words/wordpieces/characters,
# we just assume they are tokenized and integerized (toy example obviously).

import jax.numpy as jnp
vocab_size = 43  # prime trick! :)
training_data = jnp.array([4, 8, 15, 16, 23, 42])

lm = LSTMLM(vocab_size=vocab_size)
print("Sample before:", lm.greedy_argmax(lm.hc_0))

bptt_length = 3  # to illustrate hc.detach-ing

for epoch in range(101):
    hc = lm.hc_0
    totalloss = 0.
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start+bptt_length]
        loss, (h, c) = lm(batch, hc)
        hc = (h.detach(), c.detach())
        if epoch % 50 == 0:
            totalloss += loss.item()
        loss.backward()
        for name, param in lm.named_parameters():
            if param.grad is not None:
                param.data -= 0.1 * param.grad
                del param.grad
    if totalloss:
        print("Loss:", totalloss)

print("Sample after:", lm.greedy_argmax(lm.hc_0))
Sample before: [42, 34, 34, 34, 34, 34]
Loss: 25.953862190246582
Loss: 3.7642268538475037
Loss: 1.9537211656570435
Sample after: [4, 8, 15, 16, 23, 42]

So, this is all nice and perhaps more intuitive than old graph-mode Theano or Tensorflow, but it comes with some annoying gotchas. Even while writing this up, I struggled with making sure I detach-ed in the right places and marked the right things as Parameters, so that just the right number of nodes in the computation graph (which PyTorch assembles from our statements in the background) are marked as intermediate and cleaned up at the right time.

2. Pure functions

To understand how JAX handles this issue, we need to understand the concept of pure functions. If you've done some functional programming before, you might be familiar with that concept: a pure function is like a function or formula in math. It defines how an output value is obtained from some input values. What's important is that it has no “side effects”: no part of the function should access or even mutate any global state.

The way we wrote our code in PyTorch was very much stateful and full of mutating state, making reasoning about it and optimizing it a bit tricky. JAX therefore chooses to constrain a programmer to pure functions that don't do any of that.

But let's look at some examples of pure functions before we dive into JAX. To confirm that a function is pure, in a nutshell, these criteria have to hold:

  1. It shouldn't matter when and in what context you execute the function—as long as the inputs are the same, the outputs should be the same.
  2. Whether we executed the function zero, one, or many times should be absolutely impossible to discern after the fact.

See how all the impure functions below violate at least one of these constraints:

import random
import time
nr_executions = 0

def pure_fn_1(x):
    return 2 * x

def pure_fn_2(xs):
    ys = []
    for x in xs:
        # Mutating stateful variables *inside* the function is fine!
        ys.append(2 * x)
    return ys

def impure_fn_1(xs):
    # Mutating arguments has lasting consequences outside the function! :(
    xs.append(sum(xs))
    return xs

def impure_fn_2(x):
    # Very obviously mutating global state is bad...
    global nr_executions
    nr_executions += 1
    return 2 * x

def impure_fn_3(x):
    # ...but just accessing it is, too, because now the function depends on the
    # execution context!
    return nr_executions * x

def impure_fn_4(x):
    # Things like IO are classic examples of impurity.
    # All three of the following lines are violations of purity:
    print("Hello!")
    user_input = input()
    execution_time = time.time()
    return 2 * x

def impure_fn_5(x):
    # Which constraint does this violate? Both, actually! You access the current
    # state of randomness *and* advance the number generator!
    p = random.random()
    return p * x

Let's see a pure function that JAX operates on: the example from the intro figure.

# (almost) 1-D linear regression
def f(w, x):
    return w * x

print(f(13., 42.))

546.0

So far, so uneventful. What JAX now allows you to do is to take a function like this and transform it into a function that, instead of returning the result, returns the gradient of the result with respect to (by default) the first parameter!

import jax
import jax.numpy as jnp

# Gradient: with respect to weights! JAX uses the first argument by default.
df_dw = jax.grad(f)

def manual_df_dw(w, x):
    return x

assert df_dw(13., 42.) == manual_df_dw(13., 42.)

print(df_dw(13., 42.))

42.0

Everything up to here you probably have seen in the JAX README and it kinda makes sense. But how do we get from here to big modules like the one in our PyTorch code?

First, let's add a bias term and try to wrap the 1-D linear regressor we get into an object like we're used to: a kind of LinearRegressor “layer”:

class LinearRegressor():
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def predict(self, x):
        return self.w * x + self.b

    def rms(self, xs: jnp.ndarray, ys: jnp.ndarray):
        return jnp.sqrt(jnp.sum(jnp.square(self.w * xs + self.b - ys)))

my_regressor = LinearRegressor(13., 0.)

# A kind of loss function, used for training
xs = jnp.array([42.0])
ys = jnp.array([500.0])
print(my_regressor.rms(xs, ys))

# Prediction for test data
print(my_regressor.predict(42.))

46.0
546.0

So far, so good. So how do we get gradients to train? We would need a pure function that has our parameters as arguments somewhere, maybe like this:

def loss_fn(w, b, xs, ys):
    my_regressor = LinearRegressor(w, b)
    return my_regressor.rms(xs=xs, ys=ys)

# We use argnums=(0, 1) to tell JAX to give us
# gradients wrt first and second parameter.
grad_fn = jax.grad(loss_fn, argnums=(0, 1))

print(loss_fn(13., 0., xs, ys))
print(grad_fn(13., 0., xs, ys))

46.0
(DeviceArray(42., dtype=float32), DeviceArray(1., dtype=float32))

Convince yourself that that's true :) (With our single data point, the loss is |w·xs − ys| = |13·42 − 500| = 46, so its gradient with respect to w is 42 and with respect to b is 1.) Now, this is workable, but clearly enumerating all parameters in the signature of loss_fn isn't feasible.

Luckily, JAX is not just comfortable differentiating with respect to scalars, vectors, and matrices, but also with respect to a number of tree-like data structures that it calls pytrees—and they include Python dicts:

def loss_fn(params, xs, ys):
    my_regressor = LinearRegressor(params['w'], params['b'])
    return my_regressor.rms(xs=xs, ys=ys)

grad_fn = jax.grad(loss_fn)

print(loss_fn({'w': 13., 'b': 0.}, xs, ys))
print(grad_fn({'w': 13., 'b': 0.}, xs, ys))

46.0
{'b': DeviceArray(1., dtype=float32), 'w': DeviceArray(42., dtype=float32)}

So this already looks nicer! We could write a training loop like this:

params = {'w': 13., 'b': 0.}

for _ in range(15):
    print(loss_fn(params, xs, ys))
    grads = grad_fn(params, xs, ys)
    for name in params.keys():
        params[name] -= 0.002 * grads[name]

# Now, predict:
LinearRegressor(params['w'], params['b']).predict(42.)

46.0
42.47003
38.940002
35.410034
31.880066
28.350098
24.820068
21.2901
17.760132
14.230164
10.700165
7.170166
3.6401978
0.110198975
3.4197998
DeviceArray(500.1102, dtype=float32)

Note that we can already make use of a few more JAX helpers for the update itself: since params and grads have the same (tree-like) structure, we can imagine laying them on top of each other and creating a new tree whose values everywhere are a “combination” of the two trees, like this:

def update_combiner(param, grad, lr=0.002):
    return param - lr * grad

params = jax.tree_multimap(update_combiner, params, grads)
# instead of:
# for name in params.keys():
#    params[name] -= 0.1 * grads[name]

3. A lacking impromptu solution: registering classes as custom pytree types

So, it works. But going back and forth between our object and the params dict is a bit annoying. One thing we can do to simplify the process is to let JAX see our so-far rather opaque LinearRegressor class as a data structure, allowing it to be used in place of the params dict we have!

For this we will need to tell JAX how our class can be broken down into a list of parameters and auxiliary information (called flattening) and how it can then reassemble the class from that auxiliary information and possibly changed parameters (called unflattening):

def flatten_linear_regressor(regressor):
    leaves = (regressor.w, regressor.b)
    aux = None  # we don't need auxiliary information for this simple class
    return (leaves, aux)

# careful, switched argument order! (unfortunate baggage from the past...)
def unflatten_linear_regressor(_aux, leaves):
    w, b = leaves
    return LinearRegressor(w, b)

jax.tree_util.register_pytree_node(
    LinearRegressor,
    flatten_linear_regressor,
    unflatten_linear_regressor,
)

Now we can use our regressor throughout, making for a very easy loss_fn:

def loss_fn(regressor, xs, ys):
    return regressor.rms(xs=xs, ys=ys)

grad_fn = jax.grad(loss_fn)

print(loss_fn(LinearRegressor(w=13., b=0.), xs, ys))

46.0

Now what do you think the function grad_fn returns? It used to be a dict of gradients in the shape of the params dict, but now...

print(grad_fn(LinearRegressor(w=13., b=0.), xs, ys))
	<__main__.LinearRegressor object at 0x7f4ea586b128>

...it is a LinearRegressor object that has gradients where the params used to be! Again, we can use jax.tree_util.tree_multimap to combine:

model = LinearRegressor(w=13., b=0.)

for _ in range(15):
    print(loss_fn(model, xs, ys))
    grads = grad_fn(model, xs, ys)
    model = jax.tree_multimap(update_combiner, model, grads)

# Now, predict:
model.predict(42.)

46.0
42.47003
38.940002
35.410034
31.880066
28.350098
24.820068
21.2901
17.760132
14.230164
10.700165
7.170166
3.6401978
0.110198975
3.4197998
DeviceArray(500.1102, dtype=float32)

So trying to make your modules and your models all flatten- and unflatten-able is one way you can avoid param dict handling (you essentially write that handling once, in your flatten and unflatten functions, and then let JAX call them).

The downside is that this requires you to write these flattening/unflattening functions for every module (though much of that can certainly be re-used in a base class like nn.Module). Even for small things like our LSTM-LM from above this solution looks plenty ugly:

class PytreeLSTMCell():
    def __init__(self, weight_ih, weight_hh, bias):
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias = bias

    def __call__(self, inputs, h, c):
        ifgo = self.weight_ih @ inputs + self.weight_hh @ h + self.bias
        i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)
        i = jax.nn.sigmoid(i)
        f = jax.nn.sigmoid(f)
        g = jnp.tanh(g)
        o = jax.nn.sigmoid(o)
        new_c = f * c + i * g
        new_h = o * jnp.tanh(new_c)
        return (new_h, new_c)

jax.tree_util.register_pytree_node(
    PytreeLSTMCell,
    lambda c: ((c.weight_ih, c.weight_hh, c.bias), None),
    lambda _, ws: PytreeLSTMCell(*ws),
)

class PytreeLSTMLM():
    def __init__(self, cell, embeddings, c_0):
        self.cell = cell
        self.embeddings = embeddings
        self.c_0 = c_0
    
    @property
    def hc_0(self):
        return (jnp.tanh(self.c_0), self.c_0)

    @jax.jit  # jit compiles with XLA, so we are a lot faster (try it without!).
    def forward(self, seq, hc):
        loss = 0.
        for idx in seq:
            loss -= jax.nn.log_softmax(self.embeddings @ hc[0])[idx]
            hc = self.cell(self.embeddings[idx,:], *hc)
        return loss, hc

    def greedy_argmax(self, hc, length=6):
        idxs = []
        for i in range(length):
            idx = jnp.argmax(self.embeddings @ hc[0])
            idxs.append(int(idx))
            hc = self.cell(self.embeddings[idx,:], *hc)
        return idxs

# These two functions are just a whole lot of unreadable YIKES
def flatten_whole_lstmlm(lm):
    flat_cell_weights, flat_cell_aux = jax.tree_util.tree_flatten(lm.cell)
    return tuple(flat_cell_weights) + (lm.embeddings, lm.c_0), flat_cell_aux

def unflatten_whole_lstmlm(aux, weights):
    flat_cell_weights = weights[:-2]
    embeddings, c_0 = weights[-2:]
    cell = jax.tree_util.tree_unflatten(aux, flat_cell_weights)
    return PytreeLSTMLM(cell, embeddings, c_0)

jax.tree_util.register_pytree_node(
    PytreeLSTMLM,
    flatten_whole_lstmlm,
    unflatten_whole_lstmlm,
)

Look at that. Nasty. (Apart from the horrors in the handler, you might've noticed that I added the @jax.jit annotation to the forward method just to make it run a bit faster—this, like grad(), only works because we made our class transparent to JAX—but it doesn't affect functionality, as you can see when you try to remove it. We won't talk about it more in this tutorial.)

But just to show you that it works, let's train it like we trained the PyTorch model. To do that, we will need initial parameters, and that brings us to another initially frightening aspect of JAX: you always have to specify where your randomness comes from!

Remember that we said that randomness is also a source of impurity. But, really, in a computer, you never get real randomness, you always get pseudo-randomness: random looking numbers generated from a seed. So, if you have a seed, generating pseudo-random numbers from it is very much a pure operation. This means that in JAX, we will always need to give every function that samples something (e.g., in initializing parameters) a seed to draw randomness from. Usually you'd nicely thread these through your program, but to keep it short we'll just hardcode them here. (Later you'll see us using them a bit more properly.)
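As a tiny standalone sketch of that principle (separate from the model code): you wrap a seed into a key, split off subkeys as you go, and the same key always yields the same numbers.

key = jax.random.PRNGKey(42)              # a seed, wrapped into a "key"
key, subkey = jax.random.split(key)       # split off a fresh subkey for this sample
noise = jax.random.uniform(subkey, (3,))  # same subkey + same shape -> same numbers

# Purity in action: sampling again with the same subkey is fully deterministic.
assert (jax.random.uniform(subkey, (3,)) == noise).all()

With that in mind, here are the initial parameters for our pytree LSTM-LM, built from hardcoded keys: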

hid_dim = 17
lm = PytreeLSTMLM(
    PytreeLSTMCell(
        jax.random.uniform(jax.random.PRNGKey(1234), (4*hid_dim, hid_dim)),
        jax.random.uniform(jax.random.PRNGKey(4321), (4*hid_dim, hid_dim)),
        jnp.zeros((4*hid_dim,)),
    ),
    jax.random.uniform(jax.random.PRNGKey(123), (vocab_size, hid_dim)),
    jnp.zeros((hid_dim,)),
)

# We are using a different LR, i.e., a new optimizer/combiner here
update_combiner_01 = lambda p, g: update_combiner(p, g, lr=0.1)

To make sure JAX takes note of us constructing (h₀,c₀) from c₀, which we deem a parameter, that too has to happen inside the “pure” function that we want to call jax.grad on, so we end up with this:

def pure_loss_fn(lm, seq, hc):
    if hc is None:
        hc = lm.hc_0
    loss, _ = lm.forward(seq, hc)
    return loss

grad_fn = jax.grad(pure_loss_fn)

print("Sample before:", lm.greedy_argmax(lm.hc_0))

bptt_length = 3
for epoch in range(101):
    totalloss = 0.
    hc = None
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start+bptt_length]
        loss, new_hc = lm.forward(batch, hc if hc else lm.hc_0)
        if epoch % 50 == 0:
            totalloss += loss.item()
        grad_lm = grad_fn(lm, batch, hc)
        lm = jax.tree_multimap(update_combiner_01, lm, grad_lm)
        hc = new_hc
    if totalloss:
        print("Loss:", totalloss)

print("Sample after:", lm.greedy_argmax(lm.hc_0))

Sample before: [0, 25, 25, 25, 25, 25]
Loss: 26.58103370666504
Loss: 4.404336929321289
Loss: 2.6979217529296875
Sample after: [4, 8, 15, 16, 23, 42]

Well, it works! But... for one, writing the handlers was rather nasty. We could imagine an nn.Module-like class like in PyTorch that automatically builds these handlers from its parameters and submodules, if we give it some method to register and keep track of them (hold that thought for later!); this would allow us to write code that is a bit closer to PyTorch.

Note: Since I wrote this blogpost, Sasha Rush approached me with a suggestion on how to turn this idea into a fully-fledged framework and I figured I'd give it a shot and implement his idea in JAX! The result, parallax, takes care of all the flattening/unflattening and registers parameters not as strings like flax or haiku, but as normal ndarrays using nice type annotations. Check it out and contribute! :)

However, there's also another annoyance with this code: we had to initialize all our parameters outside of the classes, far away from where we use them. That's really undesirable! Imagine we want to exchange our LSTM cell for a GRU cell. We would have to look up how exactly the GRU cell is implemented and what parameters of which shapes it uses, and then rewrite our initialization block to cater to that.

In contrast, what we are going to work our way towards is a framework that allows us to specify initializations for parameters inside the very modules that use them, and, more impressively yet, turn stateful-looking interactions with these registered parameters into pure functions that JAX can easily operate on!

4. Fancy objects manage parameters

To see the basic idea in action, consider the good old w * x example:

Purifying an impure function

On the left is the class we'd like to write PyTorch-style, and on the right we have a class that houses a pure function like we've been writing manually before.

The key insight is that we can very easily generate that “purified” function by controlling the context—specifically by setting self.w to reference the tensor w that was given as input, thus making sure that the formerly purity-violating access now only accesses inputs of the function, making the entire thing pure. What's important is that we don't need to know what f is doing with x or w when we write purified_f: all the logic of f is still happening, just now with a controlled context.

Let's implement a “purifying” regressor that does exactly this. We have two parameters, w and b, so we're going back to the parameter dictionary approach to keep them together. We will also wrap the problematic access (the thick arrow) in a method call, get_param, instead of referencing self.w (or, really, self.params) directly. That method (for now) will just make sure we are actually running in “pure mode” and throw an error otherwise.

class PurifyingRegressor():

    # This is the function the user wants to write.
    def impure_user_fn(self, x):
        w = self.get_param('w')
        b = self.get_param('b')
        return w * x + b

    # This is the wrapping/mode logic.
    params = {}
    is_running_pure = False

    def get_param(self, name):
        if not self.is_running_pure:
            raise Exception("We can only call this when wrapped!")
        else:
            return self.params[name]

    # This is the function that we want JAX to use!
    # Note how it defers all calculation and logic to
    # the user-written and just provides *context*!
    def pure_wrapped_user_fn(self, params, *args):
        self.is_running_pure = True
        self.params = params
        result = self.impure_user_fn(*args)
        self.is_running_pure = False
        return result

Convince yourself that it works:

pure_predict_fn = PurifyingRegressor().pure_wrapped_user_fn

# Show that it's equal to this manually purified function:
manually_pure_predict_fn = lambda params, x: params['w'] * x + params['b']

args = ({'w': 13., 'b': 0.}, 42.)

assert pure_predict_fn(*args) == manually_pure_predict_fn(*args)

If that all makes sense to you, let's move on to introducing a second mode: an initialization mode! The reason is that keeping track of all parameters outside of the class and its code, like we did above, is really annoying. It would be much nicer if we could give an initialization for a parameter right where we use it (this will also allow us to do some shape inference, e.g., only specify the output dimensionality of a linear layer and infer the input dimension from the incoming data).

To do just that, we will extend get_param to also take an initializer argument (and extract all this magic that we're writing into a simple nn.Module-like class that others will be able to inherit from):

class FlatFancyNNModule():
    def __init__(self, name):
        self.name = name
        self.params = None
        self.are_we_pure = False
        self.are_we_initializing = False

    # This is where, depending on mode, we'll give the user-written
    # impure function different things!
    def get_param(self, name, shape, initializer):
        if self.are_we_initializing:
            if name not in self.params:
                self.params[name] = initializer(shape)
            return self.params[name]
        elif self.are_we_pure:
            return self.params[name]
        else:
            raise Exception("Can't access parameters outside of context!")

    # This function calls the initializers to give us an initial params dict.    
    def initial_params(self, method, *args):
        self.are_we_initializing = True
        self.params = {}
        method(*args)
        self.are_we_initializing = False
        return self.params

    # This function returns a pure version of the function.
    def purify_method(self, method):
        def pure_method(params, *args):
            self.are_we_pure = True
            self.params = params
            result = method(*args)
            self.are_we_pure = False
            self.params = None
            return result
        return pure_method

So, to recap, there are now two “modes” that an impure function we wrote can be called in:

  1. initialization mode, where accessing a parameter calls the initializer, the result is a params dict with all these initial values
  2. pure running mode, where we simulate purity by making sure the “stateful” class member params is actually the pure argument that is fed in.

Let's test it on our simple example! An implementation of a user-written module will look like this:

class FancyLinearRegressor(FlatFancyNNModule):
    def __init__(self, name="linreg"):
        super().__init__(name=name)
    
    def predict(self, x):
        # Our "initializers" are a bit simple here.
        w = self.get_param('w', (0,), lambda _shape: 13.)
        b = self.get_param('b', (0,), lambda _shape: 0.)
        return w * x + b

    def rms(self, xs, ys):
        # Same as before.
        w = self.get_param('w', (0,), lambda _shape: 13.)
        b = self.get_param('b', (0,), lambda _shape: 0.)
        return jnp.sqrt(jnp.sum(jnp.square(w * xs + b - ys)))

And this is how we would transform it:

my_regressor = FancyLinearRegressor()

sample_input = 999.  # a placeholder, necessary so the method can be run completely
params = my_regressor.initial_params(my_regressor.predict, sample_input)

print("Params:", params)

pure_predict_fn = my_regressor.purify_method(my_regressor.predict)

pure_predict_fn(params, 42.)

Params: {'w': 13.0, 'b': 0.0}
546.0

Nice! And of course, we can also train:

pure_rms_fn = my_regressor.purify_method(my_regressor.rms)
grad_rms_fn = jax.grad(pure_rms_fn)

params = my_regressor.initial_params(my_regressor.predict, sample_input)
for _ in range(15):
    print(pure_rms_fn(params, xs, ys))
    grads = grad_rms_fn(params, xs, ys)
    params = jax.tree_multimap(update_combiner, params, grads)

pure_predict_fn(params, 42.)

46.0
42.47003
38.940002
35.410034
31.880066
28.350098
24.820068
21.2901
17.760132
14.230164
10.700165
7.170166
3.6401978
0.110198975
3.4197998
DeviceArray(500.1102, dtype=float32)

And it works! So now I hope you really want to move on and see our LSTM-LM implemented—but if you were to do that right now, you'd notice an issue: our implementation only works on parameters that are directly requested inside the module itself! It doesn't support submodules—but we definitely want those for our LM, which will both own some parameters and own an LSTMCell that owns its own. So, to make sure that submodules of a given object that are also FancyNNModule descendants follow suit, we'll update our implementation to also keep a list of submodules around and recurse through all of those on all mode changes:

class FancyNNModule():
    def __init__(self, name):
        self.name = name
        self.params = None
        self.are_we_pure = False
        self.are_we_initializing = False
        # Implementations should register their submodules here!
        self.submodules = {}

    # This is where, depending on mode, we'll give the user-written
    # impure function different things!
    def get_param(self, name, shape, initializer):
        if self.are_we_initializing:
            if name not in self.params:
                self.params[name] = initializer(shape)
            return self.params[name]
        elif self.are_we_pure:
            return self.params[name]
        else:
            raise Exception("Can't access parameters outside of context!")

    def set_are_we_pure(self, ispure):
        if ispure: self.params = {}
        self.are_we_pure = ispure
        for sm in self.submodules.values():
            sm.set_are_we_pure(ispure)
        if not ispure: self.params = None

    def set_are_we_initializing(self, isinit):
        if isinit:
            self.params = {}
        self.are_we_initializing = isinit
        for sm in self.submodules.values():
            sm.set_are_we_initializing(isinit)

    # This method gathers params from this module and all submodules.
    def gather_params(self):
        params = self.params
        for sm_name, sm in self.submodules.items():
            for name, value in sm.gather_params().items():
                params[sm_name + "/" + name] = value
        return params

    # This method spreads out params into self and all submodules.
    def spread_params(self, params):
        for name, value in params.items():
            path = name.split("/")
            if len(path) == 1:
                self.params[name] = value
            else:
                self.submodules[path[0]].spread_params({"/".join(path[1:]): value})

    # This function calls the initializers to give us an initial params dict.    
    def initial_params(self, method, *args):
        self.set_are_we_initializing(True)
        method(*args)
        self.set_are_we_initializing(False)
        params = self.gather_params()
        return params

    # This function returns a pure version of the function.
    def purify_method(self, method):
        def pure_method(params, *args):
            self.set_are_we_pure(True)
            self.spread_params(params)
            result = method(*args)
            self.set_are_we_pure(False)
            return result
        return pure_method

Note that this is still a very naive implementation that will not work for general cases, but it will serve to illustrate the rough idea—a real implementation would properly manage a frame stack.
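Just to give a flavor of what that could mean (a purely illustrative sketch with made-up names, not how any real library implements it): instead of flags and recursion over submodules, a global stack of “frames” would carry the current params and mode, and nested module calls would simply push and pop their context.

# Hypothetical sketch of a frame stack; the names (Frame, frame_stack,
# current_frame) are invented for illustration only.
frame_stack = []

class Frame:
    def __init__(self, params, is_initializing=False):
        self.params = params
        self.is_initializing = is_initializing

    def __enter__(self):
        frame_stack.append(self)
        return self

    def __exit__(self, *exc_info):
        frame_stack.pop()

def current_frame():
    if not frame_stack:
        raise Exception("Can't access parameters outside of a transform!")
    return frame_stack[-1]

# get_param would then read from (and, when initializing, write to)
# current_frame().params, and purify_method would run the user function
# inside `with Frame(params): ...` instead of toggling flags on self.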

But, foregoing that, let's finally implement the LSTM-LM for the third time (this time, we'll use jax.random.split for threading randomness, yes through an ugly global variable... but shhhh...):

random_key = jax.random.PRNGKey(0)

def unif_initializer(shape):
    global random_key
    sample_key, random_key = jax.random.split(random_key)
    return jax.random.uniform(sample_key, shape)

class FancyLSTMCell(FancyNNModule):
    def __init__(self, hid_dim, name="lstmcell"):
        super().__init__(name=name)
        self.hid_dim = hid_dim

    def __call__(self, inputs, h, c):
        weight_ih = self.get_param("weight_ih",
                (4*self.hid_dim, self.hid_dim), unif_initializer)
        weight_hh = self.get_param("weight_hh",
                (4*self.hid_dim, self.hid_dim), unif_initializer)
        bias = self.get_param("bias",
                (4*self.hid_dim,), lambda shape: jnp.zeros(shape))

        ifgo = weight_ih @ inputs + weight_hh @ h + bias
        i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)
        i = jax.nn.sigmoid(i)
        f = jax.nn.sigmoid(f)
        g = jnp.tanh(g)
        o = jax.nn.sigmoid(o)
        new_c = f * c + i * g
        new_h = o * jnp.tanh(new_c)
        return (new_h, new_c)

class FancyLSTMLM(FancyNNModule):
    def __init__(self, vocab_size, dim, name="lstmlm"):
        super().__init__(name=name)
        self.vocab_size = vocab_size
        self.dim = dim
        # Now create a submodule and register it!
        self.cell = FancyLSTMCell(dim)
        self.submodules[self.cell.name] = self.cell
    
    @property
    def hc_0(self):
        _c_0 = self.get_param("c_0",
                (self.dim,), lambda shape: jnp.zeros(shape))
        return (jnp.tanh(_c_0), _c_0)

    def forward(self, seq, hc):
        loss = 0.
        embeddings = self.get_param("embeddings",
                (self.vocab_size, self.dim), unif_initializer)
        for idx in seq:
            loss -= jax.nn.log_softmax(embeddings @ hc[0])[idx]
            hc = self.cell(embeddings[idx,:], *hc)
        return loss, hc

    def greedy_argmax(self, hc, length=6):
        idxs = []
        embeddings = self.get_param("embeddings",
                (self.vocab_size, self.dim), unif_initializer)
        for i in range(length):
            idx = jnp.argmax(embeddings @ hc[0])
            idxs.append(int(idx))
            hc = self.cell(embeddings[idx,:], *hc)
        return idxs

Again, we have to make sure we “catch” the (h,c)₀ construction in our transformation process, so we define three pure functions and start training:

lm = FancyLSTMLM(vocab_size, 17)

# Since jitting and our mutable-state-galore "tracer" don't play nice, only
# call jit on the outermost level: the actually pure function!
pure_sample_fn = lm.purify_method(
    lambda: lm.greedy_argmax(lm.hc_0))
pure_forward_fn = jax.jit(lm.purify_method(
    lambda seq, hc: lm.forward(seq, hc if hc else lm.hc_0)))
grad_loss_fn = jax.jit(jax.grad(lm.purify_method(
    lambda seq, hc: lm.forward(seq, hc if hc else lm.hc_0)[0])))

params = lm.initial_params(lambda: lm.forward(jnp.array([0]), lm.hc_0))
print("All parameters, recursively found:", list(params.keys()))

print("Sample before:", pure_sample_fn(params))

bptt_length = 3
for epoch in range(101):
    totalloss = 0.
    hc = None
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start+bptt_length]
        loss, new_hc = pure_forward_fn(params, batch, hc)
        if epoch % 50 == 0:
            totalloss += loss.item()
        grads = grad_loss_fn(params, batch, hc)
        params = jax.tree_multimap(update_combiner_01, params, grads)
        hc = new_hc
    if totalloss:
        print("Loss:", totalloss)

print("Sample after:", pure_sample_fn(params))

All parameters, recursively found: ['c_0', 'embeddings', 'lstmcell/weight_ih', 'lstmcell/weight_hh', 'lstmcell/bias']
Sample before: [0, 20, 20, 20, 20, 20]
Loss: 24.401604652404785
Loss: 3.0977354049682617
Loss: 1.744471549987793
Sample after: [4, 8, 15, 16, 23, 42]

5. We could've used haiku.transform all along

Here's the plot twist: what we've done is basically implement a poor version of haiku's transform()! The function haiku.transform(impure_fn) returns both an init_fn and an apply_fn. The init_fn corresponds to our initial_params() method, and the apply_fn is the “purified” version of the impure_fn we provide.
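As a micro-example of that correspondence (a sketch using our trusty regressor; signatures as in the haiku version used in this post, where apply_fn takes just the params plus the original arguments):

import haiku as hk

def impure_predict(x):
    w = hk.get_parameter("w", shape=(), init=hk.initializers.Constant(13.))
    b = hk.get_parameter("b", shape=(), init=hk.initializers.Constant(0.))
    return w * x + b

init_fn, apply_fn = hk.transform(impure_predict)

params = init_fn(jax.random.PRNGKey(0), 42.)  # like our initial_params()
print(apply_fn(params, 42.))                  # the purified version: 546.0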

So, without further ado, let's do one last LSTM implementation, using haiku.transform() and a mixture of self-allocated parameters and haiku layers (which, again, you shouldn't take as a particularly pretty canonical example):

import haiku as hk

class HaikuLSTMCell(hk.Module):
    def __init__(self, in_dim, out_dim, name=None):
        super().__init__(name=name or "lstmcell")
        self.in_dim = in_dim
        self.out_dim = out_dim

    def __call__(self, inputs, h, c):
        weight_ih = hk.get_parameter("weight_ih",
                (4*self.out_dim, self.in_dim),
                init=hk.initializers.UniformScaling())
        weight_hh = hk.get_parameter("weight_hh",
                (4*self.out_dim, self.out_dim),
                init=hk.initializers.UniformScaling())
        bias = hk.get_parameter("bias",
                (4*self.out_dim,),
                init=hk.initializers.Constant(0.0))

        ifgo = weight_ih @ inputs + weight_hh @ h + bias
        i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)
        i = jax.nn.sigmoid(i)
        f = jax.nn.sigmoid(f)
        g = jnp.tanh(g)
        o = jax.nn.sigmoid(o)
        new_c = f * c + i * g
        new_h = o * jnp.tanh(new_c)
        return (new_h, new_c)

class HaikuLSTMLM(hk.Module):
    def __init__(self, vocab_size, dim, name=None):
        super().__init__(name=name or "lstmlm")
        _c0 = hk.get_parameter("c_0",
                (dim,),
                init=hk.initializers.TruncatedNormal(stddev=0.1))
        self.hc_0 = (jnp.tanh(_c0), _c0)
        self.embeddings = hk.Embed(vocab_size, dim)
        self.cell = HaikuLSTMCell(dim, dim)

    # @jax.jit
    def forward(self, seq, hc):
        loss = 0.
        for idx in seq:
            loss -= jax.nn.log_softmax(self.embeddings.embeddings @ hc[0])[idx]
            hc = self.cell(self.embeddings(idx), *hc)
        return loss, hc

    def greedy_argmax(self, hc, length=6):
        idxs = []
        for i in range(length):
            idx = jnp.argmax(self.embeddings.embeddings @ hc[0])
            idxs.append(int(idx))
            hc = self.cell(self.embeddings(idx), *hc)
        return idxs

Note that because Haiku mandates that all __init__ calls of haiku.Modules happen inside a transform, we just go ahead and also allocate our parameters and haiku modules right there—as long as that __init__ is part of the impure function you transform in the end, that's no problem:

def impure_sample_fn():
    lm = HaikuLSTMLM(vocab_size, 17)
    return lm.greedy_argmax(lm.hc_0)

def impure_forward_fn(seq, hc):
    lm = HaikuLSTMLM(vocab_size, 17)
    return lm.forward(seq, hc if hc else lm.hc_0)

_, pure_sample_fn = hk.transform(impure_sample_fn)
init_fn, nojit_pure_forward_fn = hk.transform(impure_forward_fn)
_, nojit_pure_loss_fn = hk.transform(lambda *args: impure_forward_fn(*args)[0])

pure_forward_fn = jax.jit(nojit_pure_forward_fn)
pure_loss_fn = jax.jit(nojit_pure_loss_fn)
grad_loss_fn = jax.jit(jax.grad(nojit_pure_loss_fn))

rng = jax.random.PRNGKey(0)  # Haiku actually manages the random number generator :)
params = init_fn(rng, jnp.array([0]), None)

from haiku._src.data_structures import frozendict
def print_params(params):
    return [
            (name, print_params(value))
            if isinstance(value, frozendict)
            else name
            for name, value in params.items()
    ]
print("All parameters, recursively found:", print_params(params))

print("Sample before:", pure_sample_fn(params))

bptt_length = 3
for epoch in range(101):
    totalloss = 0.
    hc = None
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start+bptt_length]
        loss, new_hc = pure_forward_fn(params, batch, hc)
        if epoch % 50 == 0:
            totalloss += loss.item()
        grads = grad_loss_fn(params, batch, hc)
        params = jax.tree_multimap(update_combiner_01, params, grads)
        hc = new_hc
    if totalloss:
        print("Loss:", totalloss)

print("Sample after:", pure_sample_fn(params))

All parameters, recursively found: [('lstmlm', ['c_0']), ('lstmlm/~/embed', ['embeddings']), ('lstmlm/~/lstmcell', ['bias', 'weight_hh', 'weight_ih'])]
Sample before: [23, 34, 7, 34, 7, 34]
Loss: 23.327383041381836
Loss: 1.491626262664795
Loss: 1.1180200576782227
Sample after: [4, 8, 15, 16, 23, 42]

So, with all this said and done, I hope you feel a little more comfortable with JAX and are ready to start coding your own stuff! As a final bonus question: how do you think JAX does it? The answer might surprise you... (and if you actually watched Skye's talk, you know it: JAX builds up a computation graph internally with placeholder nodes!)
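If you want to see that for yourself, jax.make_jaxpr lets you peek at the little symbolic program (the “jaxpr”) that tracing a function with such placeholder values produces; a quick aside, just for illustration:

def f(w, x):
    return w * x

# Tracing f with abstract placeholder values yields its jaxpr: a tiny program
# containing a single `mul` primitive, which grad(), jit(), vmap() etc. then rewrite.
print(jax.make_jaxpr(f)(13., 42.))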

There's plenty more to explore and figure out there still, but with this knowledge you should have a rough idea of what you're looking for in frameworks like flax, trax, and haiku, or, if you don't mind writing pure code all throughout, how to do just that! :) Or, if you have the time and energy, try and implement a “mini-framework” yourself, maybe based on the ideas above—who knows, maybe you'll find just the right abstraction? ;)

Also, if you want to start using JAX for your code, make sure to read up on jit(), vmap(), pmap() and all the other things it offers! There's lots of cool stuff to cover at a future time...

Thanks for reading! Feedback welcome on Twitter: @sjmielke.


Thanks to Igor Babuschkin, Jasmijn Bastings, Anton Belyy, Jason Eisner, Matthew Honnibal, Xiang Lorraine Li, Madison May, Pamela Shapiro, and Suzanna Sia for their feedback, comments, and suggestions on a draft of this post!


Please feel free to cite this post using BibTeX like this:

@misc{
	Mie2020From,
	title={From PyTorch to JAX: towards neural net frameworks that purify stateful code},
	url={https://sjmielke.com/jax-purify.htm},
	author={Sabrina J. Mielke},
	year={2020},
	month={Mar}
}