From PyTorch to JAX: towards neural net frameworks that purify stateful code
2020-03-09
Update 2021-07-01: I gave a talk at the Flax/JAX community week largely based on this blogpost, but made a bit more concise and punchy and including an example of flax's new linen API! The talk is half Google Slides and half Google Colab notebook, and the recording is on YouTube :)
JAX, Google's now-over-a-year-old Python library for machine learning and other numerical computing, describes itself as “Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more”. While that definition is certainly fitting, it is a bit intimidating. I would describe JAX as numpy, but on GPU, and then move on to the one feature we will be most concerned with in this post: its autodifferentiation capability, i.e., how to get gradients of some loss function your code computes with respect to your input parameters. If you haven't heard of JAX at all, I can recommend Skye Wanderman-Milne's talk at NeurIPS on JAX (or check out the corresponding slides). It's a cool framework with cool ideas!
That said, moving from PyTorch or TensorFlow 2 to JAX is a huge change: the way we build up computation and, more importantly, backpropagate through it is fundamentally different in the two! PyTorch builds up a graph as you compute the forward pass, and one call to backward() on some “result” node then augments each intermediate node in the graph with the gradient of the result node with respect to that intermediate node. JAX, on the other hand, makes you express your computation as a Python function, and transforming it with grad() gives you a gradient function that you can evaluate like your computation function. Instead of the output, though, it gives you the gradient of the output with respect to (by default) the first parameter that your function took as input.
This has consequences for how you write code and build up models in both frameworks, of course. So when you're used to tape-based auto-differentiation and working with stateful objects in PyTorch or TensorFlow 2, coming to JAX may be quite a shock. And while running grad() on numpy one-liners like f(w, x) = w * x (which we will actually run later below) is cool and all, you wonder what a minimal example for, say, a language model would look like (language models aren't quite as straightforward to implement as ResNets: they can have dynamic structures that aren't always nicely divisible into “layers”).
Maybe you decided to look at libraries like flax, trax, or haiku, and what you see, at least in the ResNet examples, looks not too dissimilar from any other framework: define some layers, run some trainers... but what is it that actually happens there? What's the route from these tiny numpy functions to training big hierarchical neural nets?
That's the niche this post is trying to fill. We will:
- quickly recap a stateful LSTM-LM implementation in a tape-based gradient framework, specifically PyTorch,
- see how PyTorch-style coding relies on mutating state, learn about mutation-free pure functions and build (pure) zappy one-liners in JAX,
- step-by-step go from individual parameters to medium-size modules by registering them as pytree nodes,
- combat growing pains by building fancy scaffolding and controlling context to extract initialized parameters and purify functions, and
- realize that we could get that easily in a framework like DeepMind's haiku using its transform mechanism.
Things we will not do in this tutorial: build state-of-the-art models or elegant, “idiomatic” codebases. We just try to build minimal examples that exhibit the complexity we are looking for. We also won't cover batching (which is easy with vmap()), distributing (pmap()), and XLA-compiling your code (jit()) in JAX, all of which are really cool features, but beside the point.
Let's get started!
1. An LSTM-LM in PyTorch
To make sure we're on the same page, let's implement the language model I want to work towards in PyTorch. To keep the comparison straightforward, we will implement things from scratch as much as possible in all three approaches. Let's start with an LSTMCell that holds some parameters:
    import torch


    class LSTMCell(torch.nn.Module):
        def __init__(self, in_dim, out_dim):
            super(LSTMCell, self).__init__()
            self.weight_ih = torch.nn.Parameter(torch.rand(4*out_dim, in_dim))
            self.weight_hh = torch.nn.Parameter(torch.rand(4*out_dim, out_dim))
            self.bias = torch.nn.Parameter(torch.zeros(4*out_dim,))

        def forward(self, inputs, h, c):
            ifgo = self.weight_ih @ inputs + self.weight_hh @ h + self.bias
            i, f, g, o = torch.chunk(ifgo, 4)
            i = torch.sigmoid(i)
            f = torch.sigmoid(f)
            g = torch.tanh(g)
            o = torch.sigmoid(o)
            new_c = f * c + i * g
            new_h = o * torch.tanh(new_c)
            return (new_h, new_c)
Next, build a super simple 1-layer LSTM language model using this cell. Note that to keep the example simple, we will just use a simple matrix for embeddings. This and the learned (h, c)_0 will demonstrate how individual parameters are registered in our solutions.
    class LSTMLM(torch.nn.Module):
        def __init__(self, vocab_size, dim=17):
            super().__init__()
            self.cell = LSTMCell(dim, dim)
            self.embeddings = torch.nn.Parameter(torch.rand(vocab_size, dim))
            self.c_0 = torch.nn.Parameter(torch.zeros(dim))

        @property
        def hc_0(self):
            return (torch.tanh(self.c_0), self.c_0)

        def forward(self, seq, hc):
            loss = torch.tensor(0.)
            for idx in seq:
                loss -= torch.log_softmax(self.embeddings @ hc[0], dim=-1)[idx]
                hc = self.cell(self.embeddings[idx,:], *hc)
            return loss, hc

        def greedy_argmax(self, hc, length=6):
            with torch.no_grad():
                idxs = []
                for i in range(length):
                    idx = torch.argmax(self.embeddings @ hc[0])
                    idxs.append(idx.item())
                    hc = self.cell(self.embeddings[idx,:], *hc)
            return idxs
To demonstrate that it works, let's train:
    torch.manual_seed(0)

    # As training data, we will have indices of words/wordpieces/characters,
    # we just assume they are tokenized and integerized (toy example obviously).
    import jax.numpy as jnp
    vocab_size = 43  # prime trick! :)
    training_data = jnp.array([4, 8, 15, 16, 23, 42])

    lm = LSTMLM(vocab_size=vocab_size)
    print("Sample before:", lm.greedy_argmax(lm.hc_0))

    bptt_length = 3  # to illustrate hc.detach-ing

    for epoch in range(101):
        hc = lm.hc_0
        totalloss = 0.
        for start in range(0, len(training_data), bptt_length):
            batch = training_data[start:start+bptt_length]
            loss, (h, c) = lm(batch, hc)
            hc = (h.detach(), c.detach())
            if epoch % 50 == 0:
                totalloss += loss.item()
            loss.backward()
            for name, param in lm.named_parameters():
                if param.grad is not None:
                    param.data -= 0.1 * param.grad
                    del param.grad
        if totalloss:
            print("Loss:", totalloss)

    print("Sample after:", lm.greedy_argmax(lm.hc_0))

Sample before: [42, 34, 34, 34, 34, 34]
Loss: 25.953862190246582
Loss: 3.7642268538475037
Loss: 1.9537211656570435
Sample after: [4, 8, 15, 16, 23, 42]
So, this is all nice and perhaps more intuitive than old graph-mode Theano or TensorFlow, but it comes with some annoying gotchas. Even while writing this up, I struggled with making sure I detach-ed in the right places and marked the right things as Parameters to make sure just the right amount of nodes in the computation graph (which PyTorch assembles through our statements in the background) are marked as intermediate and cleaned up at the right time.
2. Pure functions
To understand how JAX handles this issue, we need to understand the concept of pure functions. If you've done some functional programming before, you might be familiar with that concept: a pure function is like a function or formula in math. It defines how an output value is obtained from some input values. What's important is that it has no “side effects”: no part of the function should access or even mutate any global state.
The way we wrote our code in PyTorch was very much stateful and full of mutating state, making reasoning about it and optimizing it a bit tricky. JAX therefore chooses to constrain a programmer to pure functions that don't do any of that.
But let's look at some examples of pure functions before we dive into JAX. To confirm that a function is pure, in a nutshell, these criteria have to hold:
- It shouldn't matter when and in what context you execute the function—as long as the inputs are the same, the outputs should be the same.
- Whether we executed the function zero, one, or many times should be absolutely impossible to discern after the fact.
See how all the impure functions below violate at least one of these constraints:
    import random
    import time

    nr_executions = 0

    def pure_fn_1(x):
        return 2 * x

    def pure_fn_2(xs):
        ys = []
        for x in xs:
            # Mutating stateful variables *inside* the function is fine!
            ys.append(2 * x)
        return ys

    def impure_fn_1(xs):
        # Mutating arguments has lasting consequences outside the function! :(
        xs.append(sum(xs))
        return xs

    def impure_fn_2(x):
        # Very obviously mutating global state is bad...
        global nr_executions
        nr_executions += 1
        return 2 * x

    def impure_fn_3(x):
        # ...but just accessing it is, too, because now the function depends on the
        # execution context!
        return nr_executions * x

    def impure_fn_4(x):
        # Things like IO are classic examples of impurity.
        # All three of the following lines are violations of purity:
        print("Hello!")
        user_input = input()
        execution_time = time.time()
        return 2 * x

    def impure_fn_5(x):
        # Which constraint does this violate? Both, actually! You access the current
        # state of randomness *and* advance the number generator!
        p = random.random()
        return p * x
Let's see a pure function that JAX operates on: the example from the intro figure.
    # (almost) 1-D linear regression
    def f(w, x):
        return w * x

    print(f(13., 42.))
546.0
So far, so uneventful. What JAX now allows you to do is to take a function like this and transform it into a function that instead of returning the result, returns the gradient of the result with respect to (by default) the first parameter!
    import jax
    import jax.numpy as jnp

    # Gradient: with respect to weights! JAX uses the first argument by default.
    df_dw = jax.grad(f)

    def manual_df_dw(w, x):
        return x

    assert df_dw(13., 42.) == manual_df_dw(13., 42.)

    print(df_dw(13., 42.))
42.0
Everything up to here you probably have seen in the JAX README and it kinda makes sense. But how do we get from here to big modules like the one in our PyTorch code?
First, let's add a bias term and try to wrap the 1-D linear regressor we get into an object like we're used to: a kind of LinearRegressor “layer”:
    class LinearRegressor():
        def __init__(self, w, b):
            self.w = w
            self.b = b

        def predict(self, x):
            return self.w * x + self.b

        def rms(self, xs: jnp.ndarray, ys: jnp.ndarray):
            return jnp.sqrt(jnp.sum(jnp.square(self.w * xs + self.b - ys)))

    my_regressor = LinearRegressor(13., 0.)

    # A kind of loss function, used for training
    xs = jnp.array([42.0])
    ys = jnp.array([500.0])
    print(my_regressor.rms(xs, ys))

    # Prediction for test data
    print(my_regressor.predict(42.))

46.0
546.0
So far, so good. So how do we get gradients to train? We would need a pure function that has our parameters as arguments somewhere, maybe like this:
    def loss_fn(w, b, xs, ys):
        my_regressor = LinearRegressor(w, b)
        return my_regressor.rms(xs=xs, ys=ys)

    # We use argnums=(0, 1) to tell JAX to give us
    # gradients wrt first and second parameter.
    grad_fn = jax.grad(loss_fn, argnums=(0, 1))

    print(loss_fn(13., 0., xs, ys))
    print(grad_fn(13., 0., xs, ys))

46.0
(DeviceArray(42., dtype=float32), DeviceArray(1., dtype=float32))
Convince yourself that that's true :)
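One way to convince yourself, sketched by hand for the single data point used above (x = 42, y = 500): with only one example, the square root of the squared residual is just its absolute value, so the gradient is the sign of the residual times the usual derivatives of a linear model.

```latex
\begin{aligned}
L(w, b) &= \sqrt{(w x + b - y)^2} = \lvert w x + b - y \rvert \\
r &= 13 \cdot 42 + 0 - 500 = 46 > 0 \\
\frac{\partial L}{\partial w} &= \operatorname{sign}(r)\, x = 42, \qquad
\frac{\partial L}{\partial b} = \operatorname{sign}(r) = 1
\end{aligned}
```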
Now, this is workable, but clearly enumerating all parameters in the head of loss_fn isn't feasible.
Luckily, JAX is not just comfortable differentiating with respect to scalars, vectors, and matrices, but also with respect to a number of tree-like data structures that it calls pytrees—and they include python dicts:
    def loss_fn(params, xs, ys):
        my_regressor = LinearRegressor(params['w'], params['b'])
        return my_regressor.rms(xs=xs, ys=ys)

    grad_fn = jax.grad(loss_fn)

    print(loss_fn({'w': 13., 'b': 0.}, xs, ys))
    print(grad_fn({'w': 13., 'b': 0.}, xs, ys))

46.0
{'b': DeviceArray(1., dtype=float32), 'w': DeviceArray(42., dtype=float32)}
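For a closer look at what JAX sees in such a dict, here is a small sketch using jax.tree_util.tree_flatten and tree_unflatten. The nested dict and its key names are made up for illustration; the point is only that a pytree is a list of leaves plus a description of the structure around them.

```python
import jax

# A nested container of parameters; the names are illustrative only.
params = {"linear": {"w": 13.0, "b": 0.0}, "scale": 2.0}

# tree_flatten returns the leaves plus a treedef describing the structure.
# Dict keys are visited in sorted order.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(leaves)
print(treedef)

# tree_unflatten rebuilds the same structure around (possibly new) leaves.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
print(doubled)
```

This is also why the gradient dict above comes back with 'b' before 'w': the leaves are ordered by sorted key.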
So this already looks nicer! We could write a training loop like this:
    params = {'w': 13., 'b': 0.}

    for _ in range(15):
        print(loss_fn(params, xs, ys))
        grads = grad_fn(params, xs, ys)
        for name in params.keys():
            params[name] -= 0.002 * grads[name]

    # Now, predict:
    LinearRegressor(params['w'], params['b']).predict(42.)

46.0
42.47003
38.940002
35.410034
31.880066
28.350098
24.820068
21.2901
17.760132
14.230164
10.700165
7.170166
3.6401978
0.110198975
3.4197998
DeviceArray(500.1102, dtype=float32)
Note that we can already make use of a bit more JAX helpers for the updating itself: since params and grads have the same (tree-like) structure, we can imagine laying them on top and creating a new tree whose values everywhere are a “combination” of the two trees like this:
    def update_combiner(param, grad, lr=0.002):
        return param - lr * grad

    params = jax.tree_multimap(update_combiner, params, grads)
    # instead of:
    # for name in params.keys():
    #     params[name] -= 0.002 * grads[name]
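A note if you are following along on a recent JAX release: as far as I know, tree_multimap has since been deprecated and removed, and jax.tree_util.tree_map accepts several trees instead. A minimal sketch under that assumption, with the gradient values hard-coded from the output above rather than computed:

```python
import jax

def update_combiner(param, grad, lr=0.002):
    return param - lr * grad

params = {"w": 13.0, "b": 0.0}
grads = {"w": 42.0, "b": 1.0}  # copied from the grad_fn output, not recomputed

# tree_map walks both trees in lockstep and calls the combiner on each pair
# of leaves; the two trees must have the same structure.
new_params = jax.tree_util.tree_map(update_combiner, params, grads)
print(new_params)
```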
3. A lacking impromptu solution: registering classes as custom pytree types
So, it works. But going back and forth between our object and the params dict is a bit annoying. One thing we can do to simplify the process is allow JAX to see our so-far rather opaque LinearRegressor class as a data structure, allowing it to be used in place of the dict params we have!
For this we will need to tell JAX how it can break our class down into a list of parameters and auxiliary information (called flattening) and how it can then reassemble the class from perhaps changed parameters and that auxiliary information (called unflattening):
    def flatten_linear_regressor(regressor):
        leaves = (regressor.w, regressor.b)
        aux = None  # we don't need auxiliary information for this simple class
        return (leaves, aux)

    # careful, switched argument order! (unfortunate baggage from the past...)
    def unflatten_linear_regressor(_aux, leaves):
        w, b = leaves
        return LinearRegressor(w, b)

    jax.tree_util.register_pytree_node(
        LinearRegressor,
        flatten_linear_regressor,
        unflatten_linear_regressor,
    )
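To see the two handlers being called, here is a small self-contained round trip. TinyRegressor is a hypothetical stand-in for LinearRegressor (a class can only be registered once per session, so the sketch uses its own).

```python
import jax

class TinyRegressor():  # hypothetical stand-in for LinearRegressor
    def __init__(self, w, b):
        self.w = w
        self.b = b

jax.tree_util.register_pytree_node(
    TinyRegressor,
    lambda r: ((r.w, r.b), None),                   # flatten: (leaves, aux)
    lambda _aux, leaves: TinyRegressor(*leaves),    # unflatten: (aux, leaves)
)

# JAX calls our flatten handler here...
leaves, treedef = jax.tree_util.tree_flatten(TinyRegressor(13.0, 0.0))
print(leaves)

# ...and our unflatten handler here, with changed leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, [leaf + 1.0 for leaf in leaves])
print(rebuilt.w, rebuilt.b)
```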
Now we can use our regressor throughout, making for a very easy loss_fn:
    def loss_fn(regressor, xs, ys):
        return regressor.rms(xs=xs, ys=ys)

    grad_fn = jax.grad(loss_fn)

    print(loss_fn(LinearRegressor(w=13., b=0.), xs, ys))
46.0
Now what do you think the function grad_fn returns? It used to be a dict of gradients in the shape of the params dict, but now...
print(grad_fn(LinearRegressor(w=13., b=0.), xs, ys))
<__main__.LinearRegressor object at 0x7f4ea586b128>
...it is a LinearRegressor object that has gradients where the params used to be! Again, we can use jax.tree_multimap to combine:
    model = LinearRegressor(w=13., b=0.)

    for _ in range(15):
        print(loss_fn(model, xs, ys))
        grads = grad_fn(model, xs, ys)
        model = jax.tree_multimap(update_combiner, model, grads)

    # Now, predict:
    model.predict(42.)

46.0
42.47003
38.940002
35.410034
31.880066
28.350098
24.820068
21.2901
17.760132
14.230164
10.700165
7.170166
3.6401978
0.110198975
3.4197998
DeviceArray(500.1102, dtype=float32)
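Since the gradient object only prints as an opaque address, it can help to look inside it. The sketch below is self-contained and again uses a hypothetical TinyRegressor in place of LinearRegressor; the gradients sit in the same attributes as the parameters they belong to.

```python
import jax
import jax.numpy as jnp

class TinyRegressor():  # hypothetical stand-in for LinearRegressor
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def rms(self, xs, ys):
        return jnp.sqrt(jnp.sum(jnp.square(self.w * xs + self.b - ys)))

jax.tree_util.register_pytree_node(
    TinyRegressor,
    lambda r: ((r.w, r.b), None),
    lambda _aux, leaves: TinyRegressor(*leaves),
)

xs = jnp.array([42.0])
ys = jnp.array([500.0])

# The result has the same type as the input, with gradients as its leaves.
grads = jax.grad(lambda model: model.rms(xs, ys))(TinyRegressor(13.0, 0.0))
print(type(grads).__name__, grads.w, grads.b)
```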
So trying to make your modules and your models all flatten- and unflatten-able is one way you can avoid param dict handling (you're essentially writing it once for your flatten and unflatten and then let JAX call them).
The downside is that this requires you to write these flattening/unflattening functions for every module (though much of that can certainly be re-used in a base class like nn.Module). Even for small things like our LSTM-LM from above this solution looks plenty ugly:
    class PytreeLSTMCell():
        def __init__(self, weight_ih, weight_hh, bias):
            self.weight_ih = weight_ih
            self.weight_hh = weight_hh
            self.bias = bias

        def __call__(self, inputs, h, c):
            ifgo = self.weight_ih @ inputs + self.weight_hh @ h + self.bias
            i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)
            i = jax.nn.sigmoid(i)
            f = jax.nn.sigmoid(f)
            g = jnp.tanh(g)
            o = jax.nn.sigmoid(o)
            new_c = f * c + i * g
            new_h = o * jnp.tanh(new_c)
            return (new_h, new_c)

    jax.tree_util.register_pytree_node(
        PytreeLSTMCell,
        lambda c: ((c.weight_ih, c.weight_hh, c.bias), None),
        lambda _, ws: PytreeLSTMCell(*ws),
    )

    class PytreeLSTMLM():
        def __init__(self, cell, embeddings, c_0):
            self.cell = cell
            self.embeddings = embeddings
            self.c_0 = c_0

        @property
        def hc_0(self):
            return (jnp.tanh(self.c_0), self.c_0)

        @jax.jit  # jit compiles with XLA, so we are a lot faster (try it without!).
        def forward(self, seq, hc):
            loss = 0.
            for idx in seq:
                loss -= jax.nn.log_softmax(self.embeddings @ hc[0])[idx]
                hc = self.cell(self.embeddings[idx,:], *hc)
            return loss, hc

        def greedy_argmax(self, hc, length=6):
            idxs = []
            for i in range(length):
                idx = jnp.argmax(self.embeddings @ hc[0])
                idxs.append(int(idx))
                hc = self.cell(self.embeddings[idx,:], *hc)
            return idxs

    # These two functions are just a whole lot of unreadable YIKES
    def flatten_whole_lstmlm(lm):
        flat_cell_weights, flat_cell_aux = jax.tree_util.tree_flatten(lm.cell)
        return tuple(flat_cell_weights) + (lm.embeddings, lm.c_0), flat_cell_aux

    def unflatten_whole_lstmlm(aux, weights):
        flat_cell_weights = weights[:-2]
        embeddings, c_0 = weights[-2:]
        cell = jax.tree_util.tree_unflatten(aux, flat_cell_weights)
        return PytreeLSTMLM(cell, embeddings, c_0)

    jax.tree_util.register_pytree_node(
        PytreeLSTMLM,
        flatten_whole_lstmlm,
        unflatten_whole_lstmlm,
    )
Look at that. Nasty.
(Apart from the horrors in the handler, you might've noticed that I added the @jax.jit annotation to the forward method just to make it run a bit faster. This, like grad(), only works because we made our class transparent to JAX, but it doesn't affect functionality, as you can see when you try to remove it. We won't talk about it more in this tutorial.)
But just to show you that it works, let's train it like we trained the PyTorch model. To do that, we will need initial parameters, and that brings us to another initially frightening aspect of JAX: you always have to specify where your randomness comes from!
Remember that we said that randomness is also a source of impurity. But, really, in a computer, you never get real randomness, you always get pseudo-randomness: random looking numbers generated from a seed. So, if you have a seed, generating pseudo-random numbers from it is very much a pure operation. This means that in JAX, we will always need to give every function that samples something (e.g., in initializing parameters) a seed to draw randomness from. Usually you'd nicely thread these through your program, but to keep it short we'll just hardcode them here. (Later you'll see us using them a bit more properly.)
    hid_dim = 17

    lm = PytreeLSTMLM(
        PytreeLSTMCell(
            jax.random.uniform(jax.random.PRNGKey(1234), (4*hid_dim, hid_dim)),
            jax.random.uniform(jax.random.PRNGKey(4321), (4*hid_dim, hid_dim)),
            jnp.zeros((4*hid_dim,)),
        ),
        jax.random.uniform(jax.random.PRNGKey(123), (vocab_size, hid_dim)),
        jnp.zeros((hid_dim,)),
    )

    # We are using a different LR, i.e., a new optimizer/combiner here
    update_combiner_01 = lambda p, g: update_combiner(p, g, lr=0.1)
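As an aside on those hardcoded seeds, here is a tiny sketch of what "sampling is pure given a key" means in practice. The seed and shapes are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1234)

# Same key in, same numbers out: sampling is a pure function of the key.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))

# To get fresh randomness, derive new keys instead of reusing the old one.
key_1, key_2 = jax.random.split(key)
c = jax.random.uniform(key_1, (3,))
d = jax.random.uniform(key_2, (3,))

print(a, b)
print(c, d)
```

That is why each parameter above gets its own key: reusing one key for two matrices of the same shape would give them identical values.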
To make sure JAX takes note of us constructing (h_0, c_0) from c_0, which we deem a parameter, that too has to happen inside the “pure” function that we want to call jax.grad on, so we end up with this:
    def pure_loss_fn(lm, seq, hc):
        if hc is None:
            hc = lm.hc_0
        loss, _ = lm.forward(seq, hc)
        return loss

    grad_fn = jax.grad(pure_loss_fn)

    print("Sample before:", lm.greedy_argmax(lm.hc_0))

    bptt_length = 3

    for epoch in range(101):
        totalloss = 0.
        hc = None
        for start in range(0, len(training_data), bptt_length):
            batch = training_data[start:start+bptt_length]
            loss, new_hc = lm.forward(batch, hc if hc else lm.hc_0)
            if epoch % 50 == 0:
                totalloss += loss.item()
            grad_lm = grad_fn(lm, batch, hc)
            lm = jax.tree_multimap(update_combiner_01, lm, grad_lm)
            hc = new_hc
        if totalloss:
            print("Loss:", totalloss)

    print("Sample after:", lm.greedy_argmax(lm.hc_0))

Sample before: [0, 25, 25, 25, 25, 25]
Loss: 26.58103370666504
Loss: 4.404336929321289
Loss: 2.6979217529296875
Sample after: [4, 8, 15, 16, 23, 42]
Well, it works! But... for one, writing the handlers was rather nasty. We could imagine an nn.Module-like class like in PyTorch that automatically builds these handlers from its parameters and submodules, if we give it some method to register and keep track of them (hold that thought for later!). This would allow us to write code that is a bit closer to PyTorch.
Note: Since I wrote this blogpost, Sasha Rush approached me with a suggestion on how to turn this idea into a fully-fledged framework and I figured I'd give it a shot and implement his idea in JAX! The result, parallax, takes care of all the flattening/unflattening and registers parameters not as strings like flax or haiku, but as normal ndarrays using nice type annotations. Check it out and contribute! :)
However, there's also another annoyance with this code: we had to initialize all our parameters outside of the classes, far away from where we use them. That's really undesirable! Imagine we want to exchange our LSTM cell for a GRU cell. We would have to look up how exactly the GRU cell is implemented and what parameters of which shapes it uses, and then rewrite our initialization block to cater to that.
In contrast, what we are going to work our way towards is a framework that allows us to specify initializations for parameters inside the modules they are used in themselves, and, more impressively yet, turn stateful looking interactions with these registered parameters into pure functions that JAX can easily operate on!
4. Fancy objects manage parameters
To see the basic idea in action, consider the good old w * x example:
On the left is the class we'd like to write PyTorch-style, and on the right we have a class that houses a pure function like we've been writing manually before.
The key insight is that we can very easily generate that “purified” function by controlling the context: specifically by setting self.w to reference the tensor w that was given as input, thus making sure that the formerly purity-violating access now only accesses inputs of the function, making the entire thing pure. What's important is that we don't need to know what f is doing with x or w when we write purified_f: all the logic of f is still happening, just now with a controlled context.
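In code, that idea could be sketched roughly like this. StatefulF and purify are names I am making up here; f, w, x and purified_f follow the description above. Treat it as a minimal sketch of the trick, not as how any library does it.

```python
import jax

class StatefulF():  # hypothetical name for the PyTorch-style class
    def __init__(self, w):
        self.w = w

    def f(self, x):
        return self.w * x  # reads state that is not among its arguments

def purify(obj):
    def purified_f(w, x):
        old_w = obj.w
        obj.w = w              # control the context...
        try:
            return obj.f(x)    # ...and run the untouched user logic
        finally:
            obj.w = old_w      # restore, so the call leaves no trace behind
    return purified_f

purified_f = purify(StatefulF(w=0.0))
print(purified_f(13.0, 42.0))
print(jax.grad(purified_f)(13.0, 42.0))
```

Note that purified_f never looks inside f: it only swaps what self.w points to, and jax.grad can then differentiate with respect to the w that was passed in.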
Let's implement a “purifying” regressor that does exactly this. We have two parameters, w and b, so we're going back to the parameter dictionary approach to keep them together. We will also wrap the problematic access (the thick arrow) in a method call get_param instead of referencing self.w (or, really, self.params) immediately. That (for now) will just make sure we are actually running in “pure mode” and throw an error otherwise.
    class PurifyingRegressor():
        # This is the function the user wants to write.
        def impure_user_fn(self, x):
            w = self.get_param('w')
            b = self.get_param('b')
            return w * x + b

        # This is the wrapping/mode logic.
        params = {}
        is_running_pure = False

        def get_param(self, name):
            if not self.is_running_pure:
                raise Exception("We can only call this when wrapped!")
            else:
                return self.params[name]

        # This is the function that we want JAX to use!
        # Note how it defers all calculation and logic to
        # the user-written and just provides *context*!
        def pure_wrapped_user_fn(self, params, *args):
            self.is_running_pure = True
            self.params = params
            result = self.impure_user_fn(*args)
            self.is_running_pure = False
            return result
Convince yourself that it works:
    # Take the bound method itself; we call it with arguments further down.
    pure_predict_fn = PurifyingRegressor().pure_wrapped_user_fn

    # Show that it's equal to this manually purified function:
    manually_pure_predict_fn = lambda params, x: params['w'] * x + params['b']

    args = ({'w': 13., 'b': 0.}, 42.)
    assert pure_predict_fn(*args) == manually_pure_predict_fn(*args)
If that all makes sense to you, let's move on to introducing a second mode: an initialization mode! The reason is that remembering all parameters outside of the class and the code like we did above is really annoying. It would be much nicer if we could give an initialization for a parameter as we use it (this will also allow us to do some shape inference, e.g., only specify the output dimensionality of a linear layer and infer the input dimension from the incoming data).
To do just that, we will extend get_param to also take a shape and an initializer argument (and extract all this magic that we're writing into a simple nn.Module-like class that others will be able to inherit from):
    class FlatFancyNNModule():
        def __init__(self, name):
            self.name = name
            self.params = None
            self.are_we_pure = False
            self.are_we_initializing = False

        # This is where, depending on mode, we'll give the user-written
        # impure function different things!
        def get_param(self, name, shape, initializer):
            if self.are_we_initializing:
                if name not in self.params:
                    self.params[name] = initializer(shape)
                return self.params[name]
            elif self.are_we_pure:
                return self.params[name]
            else:
                raise Exception("Can't access parameters outside of context!")

        # This function calls the initializers to give us an initial params dict.
        def initial_params(self, method, *args):
            self.are_we_initializing = True
            self.params = {}
            method(*args)
            self.are_we_initializing = False
            return self.params

        # This function returns a pure version of the function.
        def purify_method(self, method):
            def pure_method(params, *args):
                self.are_we_pure = True
                self.params = params
                result = method(*args)
                self.are_we_pure = False
                self.params = None
                return result
            return pure_method
So, to recap, there are now two “modes” that an impure function we wrote can be called in:
- initialization mode, where accessing a parameter calls the initializer; the result is a params dict with all these initial values
- pure running mode, where we simulate purity by making sure the “stateful” class member params is actually the pure argument that is fed in.
Let's test it on our simple example! An implementation of a user-written module will look like this:
    class FancyLinearRegressor(FlatFancyNNModule):
        def __init__(self, name="linreg"):
            super().__init__(name=name)

        def predict(self, x):
            # Our "initializers" are a bit simple here.
            w = self.get_param('w', (0,), lambda _shape: 13.)
            b = self.get_param('b', (0,), lambda _shape: 0.)
            return w * x + b

        def rms(self, xs, ys):
            # Same as before.
            w = self.get_param('w', (0,), lambda _shape: 13.)
            b = self.get_param('b', (0,), lambda _shape: 0.)
            return jnp.sqrt(jnp.sum(jnp.square(w * xs + b - ys)))
And this is how we would transform it:
    my_regressor = FancyLinearRegressor()

    sample_input = 999.  # a placeholder, necessary so the method can be run completely
    params = my_regressor.initial_params(my_regressor.predict, sample_input)
    print("Params:", params)

    pure_predict_fn = my_regressor.purify_method(my_regressor.predict)
    pure_predict_fn(params, 42.)

Params: {'w': 13.0, 'b': 0.0}
546.0
Nice! And of course, we can also train:
    pure_rms_fn = my_regressor.purify_method(my_regressor.rms)
    grad_rms_fn = jax.grad(pure_rms_fn)

    params = my_regressor.initial_params(my_regressor.predict, sample_input)

    for _ in range(15):
        print(pure_rms_fn(params, xs, ys))
        grads = grad_rms_fn(params, xs, ys)
        params = jax.tree_multimap(update_combiner, params, grads)

    pure_predict_fn(params, 42.)

46.0
42.47003
38.940002
35.410034
31.880066
28.350098
24.820068
21.2901
17.760132
14.230164
10.700165
7.170166
3.6401978
0.110198975
3.4197998
DeviceArray(500.1102, dtype=float32)
And it works! So now I hope you really want to move on and see our LSTM-LM implemented. But if you were to do that right now, you'd notice an issue: our implementation only works on parameters that are directly requested inside the module itself! It doesn't support submodules, but we definitely want those for our LM, which will both own some parameters and own an LSTMCell class that owns its own. So, to make sure that submodules of a given object that are also FancyNNModule descendants follow suit, we'll update our implementation to also keep a dict of submodules around and recurse through all those on all mode changes:
    class FancyNNModule():
        def __init__(self, name):
            self.name = name
            self.params = None
            self.are_we_pure = False
            self.are_we_initializing = False
            # Implementations should register their submodules here!
            self.submodules = {}

        # This is where, depending on mode, we'll give the user-written
        # impure function different things!
        def get_param(self, name, shape, initializer):
            if self.are_we_initializing:
                if name not in self.params:
                    self.params[name] = initializer(shape)
                return self.params[name]
            elif self.are_we_pure:
                return self.params[name]
            else:
                raise Exception("Can't access parameters outside of context!")

        def set_are_we_pure(self, ispure):
            if ispure:
                self.params = {}
            self.are_we_pure = ispure
            for sm in self.submodules.values():
                sm.set_are_we_pure(ispure)
            if not ispure:
                self.params = None

        def set_are_we_initializing(self, isinit):
            if isinit:
                self.params = {}
            self.are_we_initializing = isinit
            for sm in self.submodules.values():
                sm.set_are_we_initializing(isinit)

        # This method gathers params from this module and all submodules.
        def gather_params(self):
            params = self.params
            for sm_name, sm in self.submodules.items():
                for name, value in sm.gather_params().items():
                    params[sm_name + "/" + name] = value
            return params

        # This method spreads out params into self and all submodules.
        def spread_params(self, params):
            for name, value in params.items():
                path = name.split("/")
                if len(path) == 1:
                    self.params[name] = value
                else:
                    self.submodules[path[0]].spread_params({"/".join(path[1:]): value})

        # This function calls the initializers to give us an initial params dict.
        def initial_params(self, method, *args):
            self.set_are_we_initializing(True)
            method(*args)
            self.set_are_we_initializing(False)
            params = self.gather_params()
            return params

        # This function returns a pure version of the function.
        def purify_method(self, method):
            def pure_method(params, *args):
                self.set_are_we_pure(True)
                self.spread_params(params)
                result = method(*args)
                self.set_are_we_pure(False)
                return result
            return pure_method
Note that this is still a very naive implementation that will not work for general cases, but it will serve to illustrate the rough idea—a real implementation would properly manage a frame stack.
But, foregoing that, let's finally implement the LSTM-LM for the third time (this time, we'll use jax.random.split for threading randomness, yes, through an ugly global variable... but shhhh...):
    random_key = jax.random.PRNGKey(0)

    def unif_initializer(shape):
        global random_key
        sample_key, random_key = jax.random.split(random_key)
        return jax.random.uniform(sample_key, shape)

    class FancyLSTMCell(FancyNNModule):
        def __init__(self, hid_dim, name="lstmcell"):
            super().__init__(name=name)
            self.hid_dim = hid_dim

        def __call__(self, inputs, h, c):
            weight_ih = self.get_param("weight_ih", (4*self.hid_dim, self.hid_dim), unif_initializer)
            weight_hh = self.get_param("weight_hh", (4*self.hid_dim, self.hid_dim), unif_initializer)
            bias = self.get_param("bias", (4*self.hid_dim,), lambda shape: jnp.zeros(shape))

            ifgo = weight_ih @ inputs + weight_hh @ h + bias
            i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)
            i = jax.nn.sigmoid(i)
            f = jax.nn.sigmoid(f)
            g = jnp.tanh(g)
            o = jax.nn.sigmoid(o)
            new_c = f * c + i * g
            new_h = o * jnp.tanh(new_c)
            return (new_h, new_c)

    class FancyLSTMLM(FancyNNModule):
        def __init__(self, vocab_size, dim, name="lstmlm"):
            super().__init__(name=name)
            self.vocab_size = vocab_size
            self.dim = dim
            # Now create a submodule and register it!
            self.cell = FancyLSTMCell(dim)
            self.submodules[self.cell.name] = self.cell

        @property
        def hc_0(self):
            _c_0 = self.get_param("c_0", (self.dim,), lambda shape: jnp.zeros(shape))
            return (jnp.tanh(_c_0), _c_0)

        def forward(self, seq, hc):
            loss = 0.
            embeddings = self.get_param("embeddings", (self.vocab_size, self.dim), unif_initializer)
            for idx in seq:
                loss -= jax.nn.log_softmax(embeddings @ hc[0])[idx]
                hc = self.cell(embeddings[idx,:], *hc)
            return loss, hc

        def greedy_argmax(self, hc, length=6):
            idxs = []
            embeddings = self.get_param("embeddings", (self.vocab_size, self.dim), unif_initializer)
            for i in range(length):
                idx = jnp.argmax(embeddings @ hc[0])
                idxs.append(int(idx))
                hc = self.cell(embeddings[idx,:], *hc)
            return idxs
Again, we have to make sure we “catch” the (h, c)₀ construction in our transformation process, so we define three pure functions and start training:
lm = FancyLSTMLM(vocab_size, 17)

# Since jitting and our mutable-state-galore "tracer" don't play nice, only
# call jit on the outermost level: the actually pure function!
pure_sample_fn = lm.purify_method(
    lambda: lm.greedy_argmax(lm.hc_0))
pure_forward_fn = jax.jit(lm.purify_method(
    lambda seq, hc: lm.forward(seq, hc if hc else lm.hc_0)))
grad_loss_fn = jax.jit(jax.grad(lm.purify_method(
    lambda seq, hc: lm.forward(seq, hc if hc else lm.hc_0)[0])))

params = lm.initial_params(lambda: lm.forward(jnp.array([0]), lm.hc_0))
print("All parameters, recursively found:", list(params.keys()))
print("Sample before:", pure_sample_fn(params))

bptt_length = 3
for epoch in range(101):
    totalloss = 0.
    hc = None
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start+bptt_length]
        loss, new_hc = pure_forward_fn(params, batch, hc)
        if epoch % 50 == 0:
            totalloss += loss.item()
        grads = grad_loss_fn(params, batch, hc)
        params = jax.tree_multimap(update_combiner_01, params, grads)
        hc = new_hc
    if totalloss:
        print("Loss:", totalloss)

print("Sample after:", pure_sample_fn(params))
All parameters, recursively found: ['c_0', 'embeddings', 'lstmcell/weight_ih', 'lstmcell/weight_hh', 'lstmcell/bias']
Sample before: [0, 20, 20, 20, 20, 20]
Loss: 24.401604652404785
Loss: 3.0977354049682617
Loss: 1.744471549987793
Sample after: [4, 8, 15, 16, 23, 42]
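The parameter update in that loop walks two pytrees (parameters and gradients) in lockstep and combines them leaf by leaf. Here is a minimal standalone sketch of that step. It assumes the combiner is a plain SGD step with learning rate 0.1 (`sgd_update` is a hypothetical stand-in for `update_combiner_01`), and it uses `jax.tree_util.tree_map`, which accepts several trees and took over the role of `jax.tree_multimap` in later JAX releases. The toy `loss_fn` is likewise made up for illustration.

```python
import jax
import jax.numpy as jnp

def sgd_update(param, grad, lr=0.1):
    # Hypothetical stand-in for update_combiner_01: one SGD step per leaf
    return param - lr * grad

def loss_fn(params, x):
    # Toy loss, only here so that there is something to differentiate
    return jnp.sum((params["w"] @ x + params["b"]) ** 2)

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(2)}
x = jnp.array([0.1, 0.2, 0.3])

# grads has exactly the same dictionary structure as params ...
grads = jax.grad(loss_fn)(params, x)
# ... so both trees can be traversed together, one leaf pair at a time
new_params = jax.tree_util.tree_map(sgd_update, params, grads)
```

Because the gradient tree mirrors the parameter tree, the same one-liner keeps working no matter how deeply the parameters are nested.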
5. We could've used haiku.transform all along
Here's the plot twist: what we've done is basically implement a poor version of haiku's transform()! The function haiku.transform(impure_fn) returns both an init_fn and an apply_fn. The init_fn corresponds to our initial_params() method, and the apply_fn is the “purified” version of the impure_fn we provide.
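To make the init/apply pairing concrete, here is a deliberately tiny toy version of the idea in plain JAX. This is a sketch for illustration only, not how Haiku is implemented: `toy_transform`, `get_param`, and the global `_frame` are hypothetical names, and unlike Haiku's init function this toy one takes no random key because every parameter simply starts out as ones.

```python
import jax
import jax.numpy as jnp

_frame = None  # parameter dict the impure function currently reads and writes

def get_param(name, shape):
    # Hypothetical helper: create the parameter on first use, else look it up
    if name not in _frame:
        _frame[name] = jnp.ones(shape)
    return _frame[name]

def toy_transform(impure_fn):
    def init_fn(*args):
        # Run once with an empty frame and collect whatever gets created
        global _frame
        _frame = {}
        try:
            impure_fn(*args)
            return dict(_frame)
        finally:
            _frame = None

    def apply_fn(params, *args):
        # Run with the given parameters installed as the current frame
        global _frame
        _frame = dict(params)
        try:
            return impure_fn(*args)
        finally:
            _frame = None

    return init_fn, apply_fn

def impure_linear(x):
    w = get_param("w", (x.shape[0],))
    return jnp.dot(w, x)

init_fn, apply_fn = toy_transform(impure_linear)
x = jnp.array([1.0, 2.0, 3.0])
params = init_fn(x)                     # discovers the parameter "w"
y = apply_fn(params, x)                 # pure: output depends only on inputs
grads = jax.grad(apply_fn)(params, x)   # so grad() works on it directly
```

Since `apply_fn` takes the parameters as its first argument, `jax.grad` differentiates with respect to them by default, just like with our purified methods.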
So, without further ado, let's do one last LSTM implementation, using haiku.transform() and a mixture of self-allocated parameters and haiku layers (which, again, you shouldn't take as a particularly pretty or canonical example):
import haiku as hk

class HaikuLSTMCell(hk.Module):
    def __init__(self, in_dim, out_dim, name=None):
        super().__init__(name=name or "lstmcell")
        self.in_dim = in_dim
        self.out_dim = out_dim

    def __call__(self, inputs, h, c):
        weight_ih = hk.get_parameter("weight_ih", (4*self.out_dim, self.in_dim),
                                     init=hk.initializers.UniformScaling())
        weight_hh = hk.get_parameter("weight_hh", (4*self.out_dim, self.out_dim),
                                     init=hk.initializers.UniformScaling())
        bias = hk.get_parameter("bias", (4*self.out_dim,),
                                init=hk.initializers.Constant(0.0))

        ifgo = weight_ih @ inputs + weight_hh @ h + bias
        i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)
        i = jax.nn.sigmoid(i)
        f = jax.nn.sigmoid(f)
        g = jnp.tanh(g)
        o = jax.nn.sigmoid(o)
        new_c = f * c + i * g
        new_h = o * jnp.tanh(new_c)
        return (new_h, new_c)

class HaikuLSTMLM(hk.Module):
    def __init__(self, vocab_size, dim, name=None):
        super().__init__(name=name or "lstmlm")
        _c0 = hk.get_parameter("c_0", (dim,),
                               init=hk.initializers.TruncatedNormal(stddev=0.1))
        self.hc_0 = (jnp.tanh(_c0), _c0)
        self.embeddings = hk.Embed(vocab_size, dim)
        self.cell = HaikuLSTMCell(dim, dim)

    # @jax.jit
    def forward(self, seq, hc):
        loss = 0.
        for idx in seq:
            loss -= jax.nn.log_softmax(self.embeddings.embeddings @ hc[0])[idx]
            hc = self.cell(self.embeddings(idx), *hc)
        return loss, hc

    def greedy_argmax(self, hc, length=6):
        idxs = []
        for i in range(length):
            idx = jnp.argmax(self.embeddings.embeddings @ hc[0])
            idxs.append(int(idx))
            hc = self.cell(self.embeddings(idx), *hc)
        return idxs
Note that because Haiku mandates that all __init__ calls of haiku.Modules happen inside a transform, we just go ahead and also allocate our parameters and haiku modules right there. As long as that __init__ is part of the impure function you transform in the end, that's no problem:
def impure_sample_fn():
    lm = HaikuLSTMLM(vocab_size, 17)
    return lm.greedy_argmax(lm.hc_0)

def impure_forward_fn(seq, hc):
    lm = HaikuLSTMLM(vocab_size, 17)
    return lm.forward(seq, hc if hc else lm.hc_0)

_, pure_sample_fn = hk.transform(impure_sample_fn)
init_fn, nojit_pure_forward_fn = hk.transform(impure_forward_fn)
_, nojit_pure_loss_fn = hk.transform(lambda *args: impure_forward_fn(*args)[0])

pure_forward_fn = jax.jit(nojit_pure_forward_fn)
pure_loss_fn = jax.jit(nojit_pure_loss_fn)
grad_loss_fn = jax.jit(jax.grad(nojit_pure_loss_fn))

rng = jax.random.PRNGKey(0)
# Haiku actually manages the random number generator :)
params = init_fn(rng, jnp.array([0]), None)

from haiku._src.data_structures import frozendict

def print_params(params):
    return [
        (name, print_params(value)) if isinstance(value, frozendict) else name
        for name, value in params.items()
    ]

print("All parameters, recursively found:", print_params(params))
print("Sample before:", pure_sample_fn(params))

bptt_length = 3
for epoch in range(101):
    totalloss = 0.
    hc = None
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start+bptt_length]
        loss, new_hc = pure_forward_fn(params, batch, hc)
        if epoch % 50 == 0:
            totalloss += loss.item()
        grads = grad_loss_fn(params, batch, hc)
        params = jax.tree_multimap(update_combiner_01, params, grads)
        hc = new_hc
    if totalloss:
        print("Loss:", totalloss)

print("Sample after:", pure_sample_fn(params))
All parameters, recursively found: [('lstmlm', ['c_0']), ('lstmlm/~/embed', ['embeddings']), ('lstmlm/~/lstmcell', ['bias', 'weight_hh', 'weight_ih'])]
Sample before: [23, 34, 7, 34, 7, 34]
Loss: 23.327383041381836
Loss: 1.491626262664795
Loss: 1.1180200576782227
Sample after: [4, 8, 15, 16, 23, 42]
So, with all this said and done, I hope you feel a little more comfortable with JAX and are ready to start coding your own stuff! As a final bonus question: how do you think JAX does it? The answer might surprise you... (and if you actually watched Skye's talk, you know it: JAX builds up a computation graph internally with placeholder nodes!)
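If you want to peek at that internal graph yourself, `jax.make_jaxpr` traces a function with placeholder values and returns the recorded computation. A minimal sketch (the function `f` is just a made-up example):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

# Trace f with an abstract stand-in for x; no actual sin is computed here
jaxpr = jax.make_jaxpr(f)(1.0)

# Each equation in the traced graph applies one primitive operation
primitive_names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
print(jaxpr)
print(primitive_names)
```

The printed jaxpr lists the primitives that were recorded while the Python function ran once on placeholders, which is what transformations like grad() and jit() then operate on.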
There's plenty more to explore and figure out there still, but with this knowledge you should have a rough idea of what you're looking for in frameworks like flax, trax, and haiku, or, if you don't mind writing pure code all throughout, how to do just that! :) Or, if you have the time and energy, try and implement a “mini-framework” yourself, maybe based on the ideas above. Who knows, maybe you'll find just the right abstraction? ;)
Also, if you want to start using JAX for your code, make sure to read up on jit(), vmap(), pmap() and all the other things it offers! There's lots of cool stuff to cover at a future time...
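As a small teaser, here is a sketch of how two of these compose: a function written for a single example gets batched with vmap() and compiled with jit(). The function `predict` and its inputs are made up for illustration; pmap() is left out because it only gets interesting with several devices.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for ONE input vector x; knows nothing about batches
    return jnp.tanh(jnp.dot(w, x))

w = jnp.array([0.5, -1.0, 2.0])
batch = jnp.arange(12.0).reshape(4, 3)  # four input vectors

# vmap adds a batch axis to x (axis 0) while sharing w (None);
# jit then compiles the batched function as a whole
batched_predict = jax.jit(jax.vmap(predict, in_axes=(None, 0)))
outputs = batched_predict(w, batch)

# Reference: the same thing with an explicit Python loop
looped = jnp.stack([predict(w, x) for x in batch])
```

Both are just function transformations like grad(), so they stack in whatever order makes sense for your code.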
Thanks for reading! Feedback welcome on Twitter: @sjmielke.
Thanks to Igor Babuschkin, Jasmijn Bastings, Anton Belyy, Jason Eisner, Matthew Honnibal, Xiang Lorraine Li, Madison May, Pamela Shapiro, and Suzanna Sia for their feedback, comments, and suggestions on a draft of this post!
Please feel free to cite this post using BibTeX like this:
@misc{Mie2020From,
  title={From PyTorch to JAX: towards neural net frameworks that purify stateful code},
  url={https://sjmielke.com/jax-purify.htm},
  author={Sabrina J. Mielke},
  year={2020},
  month={Mar}
}