explaingit

google/flax

7,195 | Jupyter Notebook | Audience · researcher | Complexity · 4/5 | Setup · moderate

TLDR

Flax is Google's Python library for building and training neural networks on top of JAX, letting researchers define models as plain Python classes that run fast on CPUs and AI accelerators.

Mindmap

mindmap
  root((flax))
    What it does
      Neural network library
      Built on JAX
      Plain Python classes
    APIs
      NNX current API
      Linen legacy API
    Tech stack
      Python
      JAX
      GPU and TPU support
    Use cases
      Image classification
      Language model inference
      Research prototyping
    Audience
      AI researchers
      Google DeepMind team

Things people build with this

USE CASE 1

Build and train a custom neural network for image classification using plain Python classes with Flax NNX.

USE CASE 2

Run inference on the Gemma language model using Flax's included example code.

USE CASE 3

Experiment with attention mechanisms and normalization layers for a research model using Flax's composable building blocks.

USE CASE 4

Prototype a new AI architecture in Flax that runs fast on GPU or TPU accelerators via JAX.

Tech stack

Python, JAX, Jupyter Notebook

Getting it running

Difficulty · moderate | Time to first run · 30 min

Requires Python 3.8+ and JAX. JAX setup varies by hardware (CPU, GPU, or TPU) and may need extra steps.

No license information is provided in this explanation.

In plain English

Flax is a Python library for building and training neural networks, developed by Google. It sits on top of JAX, Google's framework for fast numerical computation, which runs well on regular CPUs and on specialized AI accelerators such as GPUs and TPUs. If you have heard of PyTorch or TensorFlow as ways to build AI models, Flax is Google's answer to the same need, built specifically around JAX.

The library gives you the building blocks to describe the structure of a neural network: layers that transform numbers, normalization steps that keep training stable, attention mechanisms used in modern language models, dropout for preventing overfitting, and more. You assemble these into a model by writing plain Python classes. These classes are easier to read, debug, and modify than the more abstract patterns that older frameworks required.

Flax has gone through two major API designs. The older one, called Linen, was released in 2020 and is still documented separately. A Linen module is a stateless description of the model. Its parameters live in a separate structure that is created at initialization and passed back in on every call. The current API, called NNX, was released in 2024. It lets you work with models as regular Python objects that own their parameters, so you can inspect and change them directly. This simplifies common research tasks.

The library is aimed at researchers who want to experiment freely. It is installed with a single pip command (pip install flax) and requires Python 3.8 or later. It comes with example code for tasks like classifying handwritten digits (MNIST) and running inference on the Gemma language model. Flax is maintained by a team at Google DeepMind but is not an official Google product.

Copy-paste prompts

Prompt 1
Using Flax NNX, write a Python class for a convolutional neural network that classifies handwritten digits from the MNIST dataset.
Prompt 2
Show me how to load and run inference on the Gemma language model using Flax's provided example code.
Prompt 3
I want to add a custom attention layer to my Flax NNX model. Write the module class and show how to inspect its parameters directly after initialization.
Prompt 4
How do I train a Flax NNX model on a GPU, save a checkpoint, and reload it to resume training from where I left off?
Prompt 5
Convert a Flax Linen model to the new NNX API. Show me the before-and-after code side by side, with the key differences explained.

Verify against the repo before relying on details.