explaingit

sanchit-gandhi/whisper-jax

4,689 · Jupyter Notebook · Audience · researcher · Complexity · 3/5 · Setup · moderate

TLDR

Whisper JAX is a faster reimplementation of OpenAI's Whisper speech-to-text model using JAX, capable of transcribing 30 minutes of audio in about 30 seconds on TPU hardware, up to 70 times faster than the original.

Mindmap

mindmap
  root((repo))
    What it does
      Speech to text
      Audio translation
      Timestamped output
    Speed advantage
      JAX compilation
      Cached compiled fn
      Batched chunks
    Hardware support
      TPU fastest
      GPU supported
      CPU fallback
    Model sizes
      Tiny 39M params
      Large model
      Multilingual variants


Things people build with this

USE CASE 1

Transcribe a long audio file into text at up to 70 times the speed of the original Whisper model by running it on a TPU.
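The headline speedups assume a TPU, so it can help to confirm which accelerator JAX actually sees before starting a long run. The sketch below uses the standard `jax.default_backend()` and `jax.devices()` calls; `describe_backend` is a hypothetical helper name used only for illustration.

```python
import jax


def describe_backend() -> str:
    """Hypothetical helper: name the default JAX backend and count its devices."""
    backend = jax.default_backend()  # typically "cpu", "gpu" or "tpu"
    n_devices = len(jax.devices())   # devices belonging to the default backend
    return f"{backend}: {n_devices} device(s)"


print(describe_backend())
```

If this reports `cpu`, Whisper JAX should still run, but the TPU timing figures will not apply.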

USE CASE 2

Batch-process audio files by splitting them into 30-second chunks and processing them in parallel across multiple accelerators.
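A minimal sketch of the chunk-then-batch idea, assuming 16 kHz mono audio (the sample rate Whisper models expect) stored as a plain list of floats. `chunk_audio` and `make_batches` are hypothetical helpers, not the library's internals; the library's own pipeline handles chunking for you and also overlaps neighbouring chunks with a stride, which this sketch omits for clarity.

```python
from typing import List

SAMPLE_RATE = 16_000                       # Whisper models expect 16 kHz audio
CHUNK_SECONDS = 30                         # Whisper's fixed input window
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS


def chunk_audio(samples: List[float]) -> List[List[float]]:
    """Split audio into 30 s windows, zero-padding the last one to full length."""
    chunks = []
    for start in range(0, len(samples), CHUNK_SAMPLES):
        chunk = samples[start:start + CHUNK_SAMPLES]
        # Pad so every chunk has the same shape (fixed shapes suit compiled code).
        chunk = chunk + [0.0] * (CHUNK_SAMPLES - len(chunk))
        chunks.append(chunk)
    return chunks


def make_batches(chunks: List[List[float]], batch_size: int) -> List[List[List[float]]]:
    """Group chunks into batches; each batch could be spread across accelerators."""
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
```

For 75 seconds of audio this yields three 30-second chunks (the last half padding), which `make_batches(chunks, 2)` groups as two batches of sizes 2 and 1. A fixed-shape implementation would typically pad that last batch as well.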

USE CASE 3

Translate spoken audio from any supported language into English text by changing a single pipeline parameter.
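A sketch of switching between transcription and translation, assuming the `FlaxWhisperPipline` class (spelled this way in the project README) and its `task` and `return_timestamps` call arguments. `call_kwargs`, `load_pipeline` and `run` are hypothetical helper names, and the checkpoint id is only an example; check the repository for the exact API before relying on it.

```python
from typing import Any, Callable, Dict

VALID_TASKS = ("transcribe", "translate")


def call_kwargs(task: str = "transcribe", return_timestamps: bool = False) -> Dict[str, Any]:
    """Keyword arguments for one pipeline call; only `task` differs between modes."""
    if task not in VALID_TASKS:
        raise ValueError(f"task must be one of {VALID_TASKS}, got {task!r}")
    return {"task": task, "return_timestamps": return_timestamps}


def load_pipeline(checkpoint: str = "openai/whisper-large-v2"):
    """Load once and reuse; requires whisper-jax to be installed."""
    # Class name as spelled in the repository README ("Pipline", not "Pipeline").
    from whisper_jax import FlaxWhisperPipline

    return FlaxWhisperPipline(checkpoint)


def run(pipeline: Callable[..., Any], audio_path: str, task: str = "transcribe") -> Any:
    """Transcribe in the source language, or translate into English with task="translate"."""
    return pipeline(audio_path, **call_kwargs(task))
```

Usage would look like `pipe = load_pipeline()` followed by `run(pipe, "audio.mp3", task="translate")`. Loading the pipeline once and passing it around keeps the compiled function cached across calls.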

USE CASE 4

Generate a timestamped transcript of a recording that marks when each segment of speech occurs in the audio.
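A sketch of turning timestamped output into an SRT subtitle file. It assumes the output follows the Hugging Face speech-recognition convention: a list of chunks, each a dict with a `"timestamp"` tuple `(start, end)` in seconds and a `"text"` string, where the final end time may be `None`. `srt_time` and `chunks_to_srt` are hypothetical helpers; verify the real output shape against the repo.

```python
from typing import Dict, List


def srt_time(seconds: float) -> str:
    """Format seconds as an SRT timestamp, HH:MM:SS,mmm."""
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"


def chunks_to_srt(chunks: List[Dict]) -> str:
    """Convert timestamped chunks into numbered SRT subtitle blocks."""
    blocks = []
    for index, chunk in enumerate(chunks, start=1):
        start, end = chunk["timestamp"]
        if end is None:  # the final segment may lack an end time; fall back to its start
            end = start
        blocks.append(
            f"{index}\n{srt_time(start)} --> {srt_time(end)}\n{chunk['text'].strip()}\n"
        )
    # SRT separates subtitle blocks with a blank line.
    return "\n".join(blocks)
```

Writing the returned string to a `.srt` file gives subtitles most video players can load.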

Tech stack

Python · JAX · Jupyter Notebook

Getting it running

Difficulty · moderate · Time to first run · 30 min

Requires installing a compatible version of JAX separately before pip-installing this library. TPU setup is easiest via the provided Kaggle notebook.

License terms are not specified in the repository documentation.

In plain English

Whisper JAX is a faster reimplementation of OpenAI's Whisper speech-to-text model, rewritten to run on a different computing framework called JAX. OpenAI's original Whisper converts audio files into text transcripts across many languages; this project runs that same model up to 70 times faster than the original, with the largest gains on Google's TPU hardware. In practice, 30 minutes of audio can be transcribed in roughly 30 seconds.

The core idea is that JAX compiles the transcription function the first time it runs, then caches the compiled version so every subsequent call is much faster. The first call therefore incurs a one-time compilation wait, but later calls are substantially quicker.

The library also supports batching: it splits a long audio file into 30-second chunks and processes them in parallel across multiple hardware accelerators. The project reports that this gives roughly a 10x speedup over transcribing the chunks sequentially, with less than a 1% penalty in word error rate.

Using the library amounts to loading a pipeline, pointing it at an audio file, and getting text back. The same pipeline can transcribe speech in its original language or translate it into English by changing a single parameter. It can also return timestamps alongside the transcript, marking when each segment of text was spoken in the recording.

The library works with any of the official Whisper model sizes, from the tiny version with 39 million parameters up to the large model, as well as multilingual variants. It runs on CPU, GPU, and TPU, though the largest speed gains come from TPU environments. For users who want more control, it also exposes lower-level building blocks that match the structure of the Hugging Face Transformers library.

Installing it requires Python 3.9, a compatible version of the JAX package installed separately, and then a pip install from the GitHub repository. A Kaggle notebook is provided to demonstrate the full setup on a cloud TPU environment.
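The compile-once, cache-afterwards behaviour can be seen with plain `jax.jit`. In this sketch, `log_magnitude` is a hypothetical toy function standing in for a model forward pass; the Python counter increments only while JAX traces the function, so it counts compilations rather than calls. It also shows why fixed-size inputs matter: a new input shape triggers a fresh compile, which is one reason padding audio to uniform 30-second chunks suits this approach.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # how many times JAX has traced (and compiled) the function


@jax.jit
def log_magnitude(frames):
    """Toy stand-in for a forward pass: log-magnitude spectrum of each frame."""
    global trace_count
    trace_count += 1  # Python side effects run only during tracing, not on cached calls
    return jnp.log1p(jnp.abs(jnp.fft.rfft(frames, axis=-1)))


batch = jnp.ones((4, 400))
log_magnitude(batch)               # first call: trace + compile (the slow one-time wait)
log_magnitude(batch * 2.0)         # same shape and dtype: reuses the cached executable
print(trace_count)
log_magnitude(jnp.ones((8, 400)))  # new shape: JAX traces and compiles again
print(trace_count)
```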

Copy-paste prompts

Prompt 1
Set up Whisper JAX on a Kaggle TPU notebook. Walk me through installing JAX for TPU and loading the large Whisper model pipeline.
Prompt 2
I have a 2-hour audio file. Show me how to use Whisper JAX to transcribe it with chunked batching enabled for maximum speed.
Prompt 3
Transcribe an audio file in French using Whisper JAX and also produce an English translation at the same time using the pipeline parameters.
Prompt 4
Run Whisper JAX on GPU instead of TPU. What changes do I need to make to the JAX installation and the pipeline configuration?
Prompt 5
Use Whisper JAX to generate a timestamped transcript and format the output as an SRT subtitle file.

Verify against the repo before relying on details.