Tiny LLM go brrrrr
March 31, 2026
In part [2/2] we will cover KV caching.
📝 Note
I build on part 1/2 here, so I highly recommend giving that a read before starting this one.
Pre-requisite
I highly recommend using nix with this shell.nix config. This should install everything you need.
{ pkgs ? import <nixpkgs> {} }:
pkgs.mkShellNoCC {
  packages = [
    (pkgs.python3.withPackages (ps: [
      ps.jax
      ps.optax
    ]))
  ];
}
📝 Note
If VS Code is not playing nice with this, do the following:
- In the VS Code terminal, run the nix-shell command
- Then run which python
- Press CMD+SHIFT+P, type Python: Select Interpreter, select Enter interpreter path... and add the path from step 2
Let's go
We implemented multi-head attention in part 1/2 and used that to generate text.
When measuring the performance of generation, the key metrics are[1]:
- Time to first token (TTFT)
- Time per output token (TPOT)
- Inter-token latency (ITL)
- End-to-end latency (E2EL)
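As a quick reference, here's how these four metrics relate for a single request, sketched as a hypothetical helper (not part of the model code):

```python
def latency_metrics(token_times_ms):
    """Derive TTFT, TPOT, avg ITL and E2EL from per-token latencies (in ms)."""
    ttft = token_times_ms[0]           # prefill + first token
    decode = token_times_ms[1:]
    tpot = sum(decode) / len(decode)   # mean time per output token
    avg_itl = tpot                     # single request: ITL equals TPOT
    e2el = sum(token_times_ms)         # total wall-clock time
    return ttft, tpot, avg_itl, e2el

# example: a slow first token (prefill) followed by fast decode steps
print(latency_metrics([500.0, 50.0, 50.0, 50.0]))  # (500.0, 50.0, 50.0, 650.0)
```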
Let's instrument these in our generate function.
from datetime import datetime

def generate(params, prompt, rand_key):
    token_times = []
    for i in range(len(prompt), CONTEXT_LENGTH):
        token_start = datetime.now()
        encoded_prompt = jax.numpy.array(encode(prompt))
        padded = jax.numpy.zeros((CONTEXT_LENGTH,), dtype=jax.numpy.int32)
        inputs = padded.at[0:len(prompt)].set(encoded_prompt)
        logits = forward(params, inputs[None, :])
        predictions = logits[0, len(prompt) - 1]
        rand_key, subkey = jax.random.split(rand_key)
        prediction = jax.random.categorical(subkey, predictions / 0.8)
        decoded_prediction = decode([int(prediction)])
        prompt += decoded_prediction
        token_end = datetime.now()
        token_times.append((token_end - token_start).total_seconds() * 1000)
        print(decoded_prediction, end="", flush=True)
    print()
    ttft = token_times[0]  # Time to First Token (ms)
    tpot = sum(token_times[1:]) / (len(token_times) - 1)  # Time Per Output Token (ms)
    avg_itl = tpot  # Inter-Token Latency is basically TPOT for a single request
    print(f"TTFT: {ttft:.2f}ms")
    print(f"TPOT: {tpot:.2f}ms")
    print(f"Avg ITL: {avg_itl:.2f}ms")
    print(f"E2EL: {sum(token_times):.2f}ms")
If you run this you should get something like:
> to be or not to be
city;
You had been with by you; feel which he common comediates,
I have not draws instruck my tongue
My tent
TTFT: 994.40ms
TPOT: 53.94ms
Avg ITL: 53.94ms
Before we dive into KV Cache, let's add @jax.jit[2] to forward and see if that gives us some performance gains.
> to be or not to be
city;
You had been with by you; feel which he common comediates,
I have not draws instruck my tongue
My tent
TTFT: 511.33ms
TPOT: 51.47ms
Avg ITL: 51.47ms
E2EL: 6121.61ms
Ok, that was neat! But can we do better?
KV Cache
Let's think about what happens during generation. When we generate the 10th token, we run the full forward pass on all 10 tokens. Then when we generate the 11th token, we run the forward pass on all 11 tokens. See the problem? We're recomputing attention for the previous tokens each time.
In attention, for each token we compute:
- Q (Query): "What am I looking for?"
- K (Key): "What do I contain?"
- V (Value): "Here's my content"
The K and V for tokens 1-10 are exactly the same when we generate token 11. We're wasting compute recalculating them. KV caching solves this by storing the K and V values from previous tokens and reusing them.
💡 Intuition: Why Cache K and V but not Q?
During generation, we only need the Query for the new token. The new token asks "what should I attend to?" and we look up the answer using the cached Keys and Values from all previous tokens.
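To convince yourself this is safe, here's a tiny numerical check (plain numpy, made-up sizes, single head): single-query attention against cached K and V reproduces the last row of full causal attention exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 5, 8  # hypothetical sequence length and embedding dim
x = rng.normal(size=(T, D))
W_q, W_k, W_v = (rng.normal(size=(D, D)) for _ in range(3))

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# full pass: recompute Q, K, V for every token, take the last token's output
# (the last row of a causal attention sees every key, so no mask is needed here)
Q, K, V = x @ W_q, x @ W_k, x @ W_v
full_last = softmax((Q[-1] @ K.T) / D**0.5) @ V

# cached path: K/V for tokens 0..T-2 come from the cache; we only compute
# Q, K, V for the newest token
K_cache, V_cache = x[:-1] @ W_k, x[:-1] @ W_v
q_new = x[-1] @ W_q
K_all = np.vstack([K_cache, x[-1] @ W_k])
V_all = np.vstack([V_cache, x[-1] @ W_v])
cached_last = softmax((q_new @ K_all.T) / D**0.5) @ V_all

assert np.allclose(full_last, cached_last)  # identical result, less compute
```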
For KV caching we are going to break generation into 2 parts:
- Prefill: This phase computes the initial values of the KV caches and generates the first token. We should not see an improvement here.
- Decode: The remaining tokens are generated in this phase. We will only send the last token instead of the whole prompt + last token like we were doing before. This phase should see a speedup and the TPOT should improve because we will use the caches from prefill.
💡 Intuition: Multi-head means multiple Ks and Vs are cached
We have multiple transformer blocks, so we will need a K and V cache per block (each cache holds all heads for that block).
Let's start by updating multihead_attention to also return K and V:
    weighted_sum = jax.numpy.reshape(
        weighted_sum, (batch_size, context_length, EMBED_DIM)
    )
-   return weighted_sum @ params["W_o"]
+   return (weighted_sum @ params["W_o"], K, V)
And update transformer_block to pass them through:
def transformer_block(params, x):
-   attention_out = multihead_attention(params, layer_norm(params["ln1"], x))
-   x = x + attention_out
+   attention, K, V = multihead_attention(params, layer_norm(params["ln1"], x))
+   x = x + attention
    x = x + ffn(params, layer_norm(params["ln2"], x))
-   return x
+   return (x, K, V)
Now let's write a new forward_prefill function. This function will build the caches and return them along with the first token.
def forward_prefill(params, inputs):
    x = embed_prefill(params, inputs)
    batch_size = inputs.shape[0]
    prompt_length = inputs.shape[1]
    kvs = []
    for block_params in params["blocks"]:
        x, K, V = transformer_block(block_params, x)
        k_cache = jax.numpy.zeros((batch_size, CONTEXT_LENGTH, EMBED_DIM))
        v_cache = jax.numpy.zeros((batch_size, CONTEXT_LENGTH, EMBED_DIM))
        k_cache = k_cache.at[:, :prompt_length].set(K)
        v_cache = v_cache.at[:, :prompt_length].set(V)
        kvs.append((k_cache, v_cache))
    return (x @ params["W_o"], kvs)
For forward_prefill we will need to add a new embedding function. This is because we won't pad the input to CONTEXT_LENGTH anymore, so our original embed function won't work.
def embed_prefill(params, inputs):
    # we only generate positional embeddings up to the prompt length, we don't pad to CONTEXT_LENGTH
    return (
        params["token_embedding"][inputs]
        + params["positional_embedding"][: inputs.shape[1]]
    )
Now let's write another function called forward_decode that will use these caches and generate the remaining tokens.
@jax.jit
def forward_decode(params, inputs, position, kvs):
    # simulate the actual position
    x = embed_at(params, inputs, position)
    new_kvs = []
    for i, block_params in enumerate(params["blocks"]):
        x, k_cache, v_cache = transformer_block_decode(
            block_params, x, position, kvs[i][0], kvs[i][1]
        )
        new_kvs.append((k_cache, v_cache))
    return (x @ params["W_o"], new_kvs)
Since we are only sending 1 token, we have to keep track of the position. We'll add another embedding function and a new transformer_block_decode function that takes the previous K and V values as input.
def embed_at(params, inputs, position):
    return params["token_embedding"][inputs] + params["positional_embedding"][position]

def transformer_block_decode(params, x, position, pre_K, pre_V):
    attention, K, V = multihead_attention_cached(
        params, layer_norm(params["ln1"], x), position, pre_K, pre_V
    )
    x = x + attention
    x = x + ffn(params, layer_norm(params["ln2"], x))
    return x, K, V
We will also write a new multihead_attention_cached function which will take the previous K and V values as input.
# inputs is always just (1, 1, EMBED_DIM)
# we have a single batch with just one new token
# k_cache, v_cache are (1, CONTEXT_LENGTH, EMBED_DIM)
def multihead_attention_cached(params, inputs, position, k_cache, v_cache):
    batch_size = inputs.shape[0]
    context_length = inputs.shape[1]
    Q = inputs @ params["W_q"]
    K_new = inputs @ params["W_k"]
    V_new = inputs @ params["W_v"]
    # write the new token's K and V into the caches at its position
    k_cache = k_cache.at[:, position].set(K_new[:, 0])
    v_cache = v_cache.at[:, position].set(V_new[:, 0])
    Q = jax.numpy.reshape(Q, (batch_size, 1, NUM_HEADS, HEAD_DIM))
    Q = Q.transpose(0, 2, 1, 3)
    K = jax.numpy.reshape(k_cache, (batch_size, CONTEXT_LENGTH, NUM_HEADS, HEAD_DIM))
    K = K.transpose(0, 2, 1, 3)
    V = jax.numpy.reshape(v_cache, (batch_size, CONTEXT_LENGTH, NUM_HEADS, HEAD_DIM))
    V = V.transpose(0, 2, 1, 3)
    # attention scores
    attention_score = Q @ K.transpose(0, 1, 3, 2)
    # scale attention scores otherwise gradients will vanish
    attention_score = attention_score / (HEAD_DIM**0.5)
    # causal mask so that we only look at the previous tokens
    causal_mask = jax.numpy.where(
        jax.numpy.arange(CONTEXT_LENGTH) > position,
        -jax.numpy.inf,
        0.0,
    )
    attention_score = attention_score + causal_mask
    attention_weights = jax.nn.softmax(attention_score)
    weighted_sum = attention_weights @ V
    # convert back to original shape
    weighted_sum = weighted_sum.transpose(
        0, 2, 1, 3
    )  # (batch_size, 1, NUM_HEADS, HEAD_DIM)
    weighted_sum = jax.numpy.reshape(
        weighted_sum, (batch_size, context_length, EMBED_DIM)
    )
    return (weighted_sum @ params["W_o"], k_cache, v_cache)
Ok, now let's update the generate function!
def generate(params, prompt, rand_key):
    token_times = []
    token_start = datetime.now()
    inputs = encode(prompt)
    # note: the [None, :] is to add a batch dimension
    logits, kvs = forward_prefill(params, jax.numpy.array(inputs)[None, :])
    # 1st batch, last element
    predictions = logits[0, -1]
    rand_key, subkey = jax.random.split(rand_key)
    prediction = jax.random.categorical(subkey, predictions / 0.8)
    token_end = datetime.now()
    token_times.append((token_end - token_start).total_seconds() * 1000)
    # Print the first generated token.
    # This is our measure of time to first token!
    print(decode([int(prediction)]), end="", flush=True)
    for i in range(len(prompt), CONTEXT_LENGTH):
        token_start = datetime.now()
        next_token = jax.numpy.array([int(prediction)])
        logits, kvs = forward_decode(
            params, next_token[None, :], i, kvs
        )  # [None, :] adds the extra batch dimension
        predictions = logits[0, 0]
        rand_key, subkey = jax.random.split(rand_key)
        prediction = jax.random.categorical(subkey, predictions / 0.8)
        token_end = datetime.now()
        token_times.append((token_end - token_start).total_seconds() * 1000)
        # Print each new token as it's generated
        print(decode([int(prediction)]), end="", flush=True)
    # Print a newline at the end
    print()
    ttft = token_times[0]  # Time to First Token (ms)
    tpot = sum(token_times[1:]) / (len(token_times) - 1)  # Time Per Output Token (ms)
    avg_itl = tpot  # Inter-Token Latency is basically TPOT for a single request
    print(f"TTFT: {ttft:.2f}ms")
    print(f"TPOT: {tpot:.2f}ms")
    print(f"Avg ITL: {avg_itl:.2f}ms")
    print(f"E2EL: {sum(token_times):.2f}ms")
Running the program now should output the following:
> to be or not to be
city;
You had been with by you; feel which he common comediates,
I have not draws instruck my tongue
My tent
TTFT: 522.87ms
TPOT: 3.94ms
Avg ITL: 3.94ms
E2EL: 956.16ms
Look at that TPOT drop! We went from ~50ms to ~4ms per token, roughly a 13x speedup on the decode phase. The TTFT stays roughly the same because we still need to process the full prompt on the first pass (prefill).
💡 Intuition: Prefill vs Decode characteristics
The two phases of generation have different characteristics:
Prefill processes the entire prompt at once. This is compute-bound because we're doing a lot of matrix multiplications. KV cache doesn't help here since we're computing K and V for the first time.
Decode generates one token at a time. This is memory-bound because we're mostly just reading the cached K and V values. KV cache shines here since we only compute K and V for one new token.
This is why production LLM systems often report TTFT and TPOT separately. They have very different optimization strategies!
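A back-of-envelope way to see the difference is to compare FLOPs per byte moved for a single weight matmul in each phase (the sizes below are assumptions for illustration, not measured from our model):

```python
# Arithmetic intensity of one projection matmul, fp32 (4 bytes/element).
T, D = 1024, 512  # hypothetical prompt length and model dim

# prefill: (T, D) @ (D, D) -> lots of FLOPs per byte of weights read
prefill_flops = 2 * T * D * D
prefill_bytes = (T * D + D * D + T * D) * 4  # input + weights + output

# decode: (1, D) @ (D, D) -> the weight matrix is read for just one token
decode_flops = 2 * D * D
decode_bytes = (D + D * D + D) * 4

print(prefill_flops / prefill_bytes)  # high FLOPs/byte -> compute-bound
print(decode_flops / decode_bytes)    # ~0.5 FLOPs/byte -> memory-bound
```

The decode-side ratio is so low that the GPU spends most of its time waiting on memory, which is exactly why caching (reading K/V instead of recomputing them) pays off there.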
KV caching isn't free. We're trading compute for memory. For each layer, we store:
- K cache: (BATCH_SIZE, NUM_HEADS, CONTEXT_LENGTH, HEAD_DIM)
- V cache: (BATCH_SIZE, NUM_HEADS, CONTEXT_LENGTH, HEAD_DIM)
With our tiny model that's not much, but for a real LLM, the KV cache can be several gigabytes per request.
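For a sense of scale, here's a rough calculation with assumed 7B-class numbers (32 layers, 32 heads of dim 128, 4096 context, fp16); these figures are illustrative, not measured:

```python
# Back-of-envelope KV cache size per request for a hypothetical config.
num_layers   = 32
num_heads    = 32
head_dim     = 128
context_len  = 4096
bytes_per_el = 2  # fp16

# K and V per layer: (context_len, num_heads * head_dim) each, hence the 2x
per_layer = 2 * context_len * num_heads * head_dim * bytes_per_el
total_gb = num_layers * per_layer / 1024**3
print(f"{total_gb:.1f} GiB per request")  # prints "2.0 GiB per request"
```

And that's for a single request at batch size 1; a server batching many concurrent requests multiplies this again.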
If you liked this post please share it with your friends!
You can find the complete implementation here.
📝 Note
You must be wondering why I pre-allocate k_cache and v_cache to be CONTEXT_LENGTH. It's because that makes the size of all inputs the same across calls. This is a common jax.jit gotcha where JAX recompiles every time the function is called with different input shapes. If we did not have a constant size for the caches and concatenated the values to the previous K and V values, then every call to multihead_attention_cached would cause recompilation and that would massively slow down the program.
I managed to do exactly that at first. See how slow the program is by checking out this commit.
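The retracing behaviour behind this gotcha is easy to demonstrate: a jitted function's Python body only runs while JAX traces it, so a counter inside the body reveals each recompilation.

```python
import jax

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # side effect runs only during tracing, not on cached calls
    return x.sum()

f(jax.numpy.ones((4,)))  # new shape -> trace
f(jax.numpy.ones((4,)))  # same shape -> cached, no retrace
f(jax.numpy.ones((5,)))  # new shape -> retrace
# trace_count is now 2
```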
References
These metrics are commonly used in production LLM systems. See the vLLM paper for more details on how they're measured and optimized. ↩︎
@jax.jit compiles your function using XLA (Accelerated Linear Algebra) for faster execution. It traces your function once and then runs the optimized version on subsequent calls. ↩︎