explaingit

vorhersager/deep-learning-jax

11 · Jupyter Notebook · Audience · researcher · Complexity · 4/5 · Active · Setup · hard

TLDR

Eleven graduate-level Jupyter notebooks teaching deep learning from autodiff up to GPT-2 fine-tuning, RLHF, and explainability using JAX, Flax, and Equinox.

Mindmap

mindmap
  root((deep-learning-jax))
    Inputs
      Math foundations
      Sample datasets
      Pretrained weights
    Outputs
      Trained models
      Loss landscape plots
      Generated images and text
    Use Cases
      Self-study deep learning
      Teach a graduate course
      Reimplement architectures
    Tech Stack
      JAX
      Flax
      Equinox
      Optax

Things people build with this

USE CASE 1

Work through 11 tutorials covering autodiff, MLPs, CNNs, RNNs, and transformers
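
The early tutorials build the habit of deriving a gradient by hand and then checking it against JAX's autodiff. Below is a minimal sketch of such a check, assuming a linear model with mean squared error loss. `mse_loss` and `manual_grads` are hypothetical names and are not code from the notebooks.

```python
import jax
import jax.numpy as jnp

def mse_loss(w, b, x, y):
    # Mean squared error of a linear model: pred = x @ w + b (illustrative model, an assumption)
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

def manual_grads(w, b, x, y):
    # Chain rule by hand: dL/dpred = 2 * (pred - y) / n
    n = x.shape[0]
    dpred = 2.0 * (x @ w + b - y) / n
    # dpred/dw = x and dpred/db = 1, so backpropagate through each
    return x.T @ dpred, jnp.sum(dpred)

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 3))
y = jax.random.normal(key_y, (8,))
w = jnp.array([0.5, -1.0, 2.0])
b = jnp.array(0.1)

# argnums=(0, 1) makes jax.grad return a tuple of gradients for w and b
auto_dw, auto_db = jax.grad(mse_loss, argnums=(0, 1))(w, b, x, y)
man_dw, man_db = manual_grads(w, b, x, y)
print(jnp.allclose(auto_dw, man_dw, atol=1e-5), jnp.allclose(auto_db, man_db, atol=1e-5))
```

The same comparison scales to an MLP: once the hand-written backward pass agrees with `jax.grad` on random inputs, it can be trusted.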

USE CASE 2

Build a Nano GPT from scratch with causal self-attention in JAX
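
The "causal" part of self-attention is a mask that stops each position from attending to later positions, which is what lets a GPT-style model generate text left to right. Below is a minimal single-head sketch in JAX. It assumes a (sequence length, model dimension) input and uses hypothetical projection matrices `w_q`, `w_k`, `w_v`; the tutorial's own module layout may differ.

```python
import jax
import jax.numpy as jnp

def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head causal self-attention over a (seq_len, d_model) input (illustrative sketch)."""
    seq_len = x.shape[0]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Scaled dot-product scores, shape (seq_len, seq_len)
    scores = q @ k.T / jnp.sqrt(k.shape[-1])
    # Lower-triangular mask: position t may attend only to positions <= t
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    # Masked entries become exactly 0 after the softmax
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights

keys = jax.random.split(jax.random.PRNGKey(0), 4)
d_model, d_head, seq_len = 16, 8, 5
x = jax.random.normal(keys[0], (seq_len, d_model))
# Hypothetical random projections standing in for learned parameters
w_q, w_k, w_v = (jax.random.normal(sub, (d_model, d_head)) / jnp.sqrt(d_model) for sub in keys[1:])
out, weights = causal_self_attention(x, w_q, w_k, w_v)
```

Every entry above the diagonal of `weights` is zero, so the output at position t depends only on tokens 0 through t.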

USE CASE 3

Train a Soft Actor-Critic racing agent with prioritized experience replay
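
Prioritized experience replay samples stored transitions in proportion to their TD error rather than uniformly. Tutorial 9 uses DeepMind's Reverb library for this, and the sketch below does not use Reverb. It illustrates the proportional rule from Schaul et al.'s prioritized replay paper in plain JAX. `sample_prioritized` is a hypothetical helper, and the `alpha` and `beta` defaults are assumptions.

```python
import jax
import jax.numpy as jnp

def sample_prioritized(key, td_errors, batch_size, alpha=0.6, beta=0.4, eps=1e-6):
    """Proportional prioritized sampling with importance-sampling weights (hypothetical helper)."""
    n = td_errors.shape[0]
    # Priority p_i = (|delta_i| + eps)^alpha; eps keeps zero-error items sampleable
    priorities = (jnp.abs(td_errors) + eps) ** alpha
    probs = priorities / priorities.sum()
    idx = jax.random.choice(key, n, shape=(batch_size,), replace=True, p=probs)
    # w_i = (n * P(i))^(-beta), normalized by the batch max so weights only scale updates down
    weights = (n * probs[idx]) ** (-beta)
    return idx, weights / weights.max()

td_errors = jnp.array([0.1, 2.0, 0.5, 0.0, 1.5])
idx, is_weights = sample_prioritized(jax.random.PRNGKey(0), td_errors, batch_size=4)
```

The importance-sampling weights multiply each sampled transition's loss term. They correct for the bias introduced by sampling high-error transitions more often.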

USE CASE 4

Reproduce a CIFAR-10 backdoor attack and detect it with Integrated Gradients
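
Integrated Gradients (Sundararajan et al.) attributes a prediction to input features by averaging gradients along a straight path from a baseline to the input and scaling by the input-baseline difference. Below is a minimal sketch using a midpoint Riemann sum. It assumes a differentiable scalar-output function `f`; for a classifier, this would be the logit of one class. The toy `f` here is a stand-in, not the tutorial's CNN.

```python
import jax
import jax.numpy as jnp

def integrated_gradients(f, x, baseline, steps=64):
    """Midpoint Riemann-sum approximation of Integrated Gradients for scalar-output f."""
    # alpha values at the midpoints of `steps` equal slices of [0, 1]
    alphas = (jnp.arange(steps) + 0.5) / steps
    alphas = alphas.reshape((steps,) + (1,) * x.ndim)
    # Points on the straight line from the baseline to the input
    path = baseline + alphas * (x - baseline)
    grads = jax.vmap(jax.grad(f))(path)
    # Average path gradient, scaled by the input-baseline difference
    return (x - baseline) * grads.mean(axis=0)

# Toy differentiable "model" standing in for a class logit (an assumption)
w = jnp.array([1.0, -2.0, 0.5])
f = lambda v: jnp.tanh(jnp.dot(w, v))
x = jnp.array([0.3, 0.1, 0.8])
baseline = jnp.zeros_like(x)
attributions = integrated_gradients(f, x, baseline)
```

The attributions should sum to approximately `f(x) - f(baseline)` (the completeness property), which makes a quick sanity check. For an image classifier, `x` would be an image and `baseline` often a black image. If a model relies on a watermark, its attributions would be expected to concentrate on the watermark pixels.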

Tech stack

JAX · Flax · Equinox · Optax · Python

Getting it running

Difficulty · hard · Time to first run · 1 day+

Later tutorials cover GANs, diffusion, RL, and distributed GPT-2 training, which need GPU resources to run end to end.

In plain English

This repository is a series of Jupyter notebooks that teach deep learning from the math up. They use the JAX numerical computing library and a few JAX-based add-ons such as Flax and Equinox. The author, John Sipple, originally wrote the material for graduate-level instruction. The goal stated in the foreword is to show readers what actually happens inside modern AI systems, rather than only how to call functions in a high-level framework. The curriculum is laid out as 11 tutorials that build on each other.

Tutorial 1 lays the mathematical groundwork. It compares TensorFlow's eager execution with JAX's function transformations and covers automatic differentiation, Jacobians, cross-entropy, and KL divergence. Tutorial 2 solves linear and ridge regression by hand, then builds a multilayer perceptron and implements backpropagation manually before checking it against JAX's autodiff. Tutorial 3 implements SGD, Nesterov momentum, RMSProp, and Adam from scratch. It then plots 3D loss surfaces showing how each optimizer moves through a landscape with saddle points and local minima.

Tutorials 4 to 6 cover the classic neural network architectures. Tutorial 4 uses Flax to build a CNN and demonstrates transfer learning by freezing a pre-trained feature extractor and training only the final layers. Tutorial 5 tackles multivariate time series forecasting with a vanilla recurrent network in pure JAX and a more robust LSTM built with Equinox. Tutorial 6 builds a small generative transformer step by step. The README calls it Nano GPT, and the build includes character-level tokenization, causal self-attention, an MLP block, and visualizations of neuron activations.

The last five tutorials turn to modern systems. Tutorial 7 starts with variational autoencoders, then turns the decoder into the generator of a Wasserstein GAN that produces synthetic face textures.
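
Two ingredients make a VAE trainable by gradient descent: the reparameterization trick and a closed-form KL term against a standard normal prior. Below is a minimal sketch, assuming an encoder that outputs a mean and a log-variance for each latent dimension. `reparameterize` and `gaussian_kl` are hypothetical names and are not taken from Tutorial 7.

```python
import jax
import jax.numpy as jnp

def reparameterize(key, mean, log_var):
    # z = mean + sigma * eps: the sampling noise lives in eps, so z stays
    # differentiable with respect to mean and log_var (hypothetical helper)
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * log_var) * eps

def gaussian_kl(mean, log_var):
    # Closed-form KL(N(mean, exp(log_var)) || N(0, 1)), summed over latent dims
    return 0.5 * jnp.sum(jnp.exp(log_var) + mean ** 2 - 1.0 - log_var, axis=-1)

key = jax.random.PRNGKey(0)
mean = jnp.ones((2, 4))
log_var = jnp.zeros((2, 4))
z = reparameterize(key, mean, log_var)
kl = gaussian_kl(mean, log_var)  # each row: 0.5 * 4 * (1 + 1 - 1 - 0) = 2.0
# Gradients flow through the sample back to the encoder outputs
dz_dmean = jax.grad(lambda m: reparameterize(key, m, log_var).sum())(mean)
```

A VAE training loss would add this KL term to a reconstruction loss computed from the decoder's output on `z`.
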
Tutorial 8 builds a small text-to-image pipeline with a transformer-based text encoder and a conditioned U-Net, using classifier-free guidance to steer the output. Tutorial 9 trains a self-driving racing agent with a Soft Actor-Critic algorithm written from scratch in JAX and Optax, with DeepMind's Reverb library handling prioritized and hindsight experience replay. Tutorial 10 walks through the full GPT-2 lifecycle, including grouped query attention, unsupervised pre-training, supervised fine-tuning, LoRA, RLHF, and distributed training with pjit and FSDP. Tutorial 11 covers explainability. The reader poisons CIFAR-10 by adding a watermark to dog images, trains a CNN that learns the shortcut, and then uses methods such as Integrated Gradients to detect the cheating. The README presents the notebooks as starting points that can be adapted to real engineering work rather than just classroom exercises.
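
LoRA fine-tunes a model by keeping a weight matrix frozen and learning only a low-rank update to it. Below is a minimal sketch of a LoRA-adapted linear layer. It assumes an (input, output) weight layout, an `alpha / rank` scaling, and a zero-initialized output-side factor as described in the LoRA paper. `init_lora` and `lora_linear` are hypothetical helpers, not Tutorial 10's implementation.

```python
import jax
import jax.numpy as jnp

def init_lora(key, d_in, d_out, rank):
    # Input-side factor is small and random; output-side factor starts at zero,
    # so the adapted layer initially matches the frozen layer exactly (hypothetical helper)
    return {"a": jax.random.normal(key, (d_in, rank)) * 0.01,
            "b": jnp.zeros((rank, d_out))}

def lora_linear(x, w_frozen, lora, alpha=16.0):
    # Frozen projection plus the low-rank update, scaled by alpha / rank
    scale = alpha / lora["a"].shape[1]
    return x @ w_frozen + scale * (x @ lora["a"]) @ lora["b"]

k_w, k_x, k_lora = jax.random.split(jax.random.PRNGKey(0), 3)
w_frozen = jax.random.normal(k_w, (32, 32)) / jnp.sqrt(32.0)
x = jax.random.normal(k_x, (4, 32))
lora = init_lora(k_lora, 32, 32, rank=4)

def loss_fn(lora_params):
    # Only the adapter is a differentiated argument; w_frozen is a closed-over constant
    return jnp.mean(lora_linear(x, w_frozen, lora_params) ** 2)

grads = jax.grad(loss_fn)(lora)  # dict with "a" and "b" entries
```

Here the adapter has 4 × (32 + 32) = 256 trainable parameters, compared with 1,024 in the frozen matrix. The adapter costs rank × (d_in + d_out) parameters instead of d_in × d_out, so the saving grows with layer width.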

Copy-paste prompts

Prompt 1
Walk me through tutorial 3's from-scratch SGD, Nesterov, RMSProp, and Adam implementations
Prompt 2
Show me how tutorial 6 builds a character-level Nano GPT with causal self-attention in JAX
Prompt 3
Explain the full GPT-2 lifecycle tutorial covering grouped query attention, LoRA, RLHF, and FSDP
Prompt 4
Reproduce the CIFAR-10 watermark backdoor experiment and run Integrated Gradients on the result

Generated 2026-05-22 · Model: sonnet-4-6 · Verify against the repo before relying on details.