From PyTorch to JAX: towards neural net frameworks that purify stateful code
2020-03-09
Moving from object-oriented PyTorch or TF2 code with tape-based backprop to JAX isn't easy. And while running grad() on numpy one-liners is cool and all, you do wonder... how do you build actual big neural nets? Maybe you decided to look at libraries like flax, trax, or haiku, and what you see, at least in the ResNet examples, looks not too dissimilar from any other framework: define some layers, run some trainers... but what is it that actually happens there? What's the route from these tiny numpy functions to training big hierarchical neural nets? Read on...

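To make the starting point concrete, here's a minimal sketch (my own example, not from the post) of the kind of grad()-on-a-numpy-one-liner usage meant above:

```python
import jax
import jax.numpy as jnp

# a plain numpy-style one-liner: a scalar-valued function of an array
def loss(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# jax.grad turns it into a new function computing d(loss)/dx
grad_loss = jax.grad(loss)

print(grad_loss(jnp.arange(3.0)))  # gradient at x = [0., 1., 2.]
```
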