A story of five bugs, bit-level debugging, and running transformer models at 2-bit precision in the browser. Here's the PR.
Background
ONNX Runtime's MatMulNBits operator supports low-bit quantized matrix multiplication — packing weight values into 2, 4, or 8 bits per element. The WebGPU execution provider (both the native C++ path and the JavaScript/JSEP path) already supported 4-bit (Q4) quantization, but 2-bit (Q2) was blocked or broken. Our goal: make Q2 with zero points work correctly end-to-end so that 2-bit quantized transformer models run accurately in the browser via WebGPU.
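To make the packing concrete, here is a minimal Python sketch of Q2 packing and unpacking. It assumes LSB-first ordering within each byte (matching the low-nibble-first convention the Q4 path uses); the helper names are illustrative, not ONNX Runtime APIs.

```python
# Sketch of Q2 bit-packing: four 2-bit values per byte, LSB-first.
# (Ordering assumed from the Q4 low-nibble-first convention.)

def pack_q2(values):
    """Pack a list of 2-bit values (0..3) into bytes, 4 per byte."""
    out = bytearray()
    for i in range(0, len(values), 4):
        b = 0
        for j, v in enumerate(values[i:i + 4]):
            b |= (v & 0x3) << (2 * j)   # value j lives at bit offset 2*j
        out.append(b)
    return bytes(out)

def unpack_q2(data, count):
    """Inverse: extract `count` 2-bit values, LSB-first per byte."""
    return [(data[i // 4] >> (2 * (i % 4))) & 0x3 for i in range(count)]

vals = [0, 1, 2, 3, 3, 2, 1, 0]
assert unpack_q2(pack_q2(vals), len(vals)) == vals
```

At this density a single u32 word holds 16 weight values, which is what makes the extraction logic later in this post so easy to get wrong.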
What seemed like a single feature gap turned out to be five distinct bugs, each hidden behind the last.
Bug 1: The Gate — Hard-Coded Rejection of Q2 + Zero Points
The first issue was immediate: attempting to run a 2-bit model with zero points threw a runtime error:
"Currently, zero points are not supported for Q2 quantization."
Two enforcement guards explicitly blocked Q2:
- Native WebGPU EP (`matmul_nbits.cc`): an `ORT_ENFORCE(nbits != 2)` when zero points were present.
- JSEP C++ kernel (`matmul_nbits.h`): `ORT_ENFORCE(nbits_ == 4)` — only Q4 was allowed at all.
Additionally, the WGSL zero-point extraction template (`matmul_nbits_zero_pt.wgsl.template`) had `#elif n_bits == 2` but was missing the `bit_mask` constant, so even if the guard were removed, the shader would malfunction.
Fix: Remove the enforcement blocks, add `const bit_mask = 0x3u;` for Q2, and guard the DP4A path (which uses a hardcoded LUT assuming `zero_point = 2`) so it skips Q2 with custom zero points.
Bug 2: Zero Point Buffer Stride Miscalculation
With the gates removed, tests ran — but produced wrong results. The root cause was in how zero_blocks_per_col was computed.
Zero points are packed into bytes: for Q4, two values per byte; for Q2, four values per byte. Each column's zero points are byte-aligned, so the shader uses a flat linear stride to skip between columns. The original formula:
```cpp
uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0
    ? n_blocks_per_col : n_blocks_per_col + 1;
```
This "+1" was a Q4 shortcut. For Q2 with n_blocks_per_col = 6 (e.g., K=384, block_size=64), the stride needs to round up to the next multiple of 4 (values per byte), not just add 1.
Fix: Proper ceiling-to-multiple formula:
```cpp
const uint32_t zp_elements_per_byte = 8 / nbits;
uint32_t zero_blocks_per_col =
    (n_blocks_per_col + zp_elements_per_byte - 1)
    / zp_elements_per_byte * zp_elements_per_byte;
```
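A quick Python check (a simulation of the two formulas, not the ORT code) shows exactly where the shortcut and the ceiling formula diverge:

```python
# Old "+1" shortcut vs. the ceiling-to-multiple fix for the zero-point
# column stride. Values per byte = 8 / nbits (2 for Q4, 4 for Q2).

def zp_stride_old(n_blocks_per_col, nbits):
    per_byte = 8 // nbits
    if n_blocks_per_col % per_byte == 0:
        return n_blocks_per_col
    return n_blocks_per_col + 1          # only correct when per_byte == 2

def zp_stride_fixed(n_blocks_per_col, nbits):
    per_byte = 8 // nbits
    return (n_blocks_per_col + per_byte - 1) // per_byte * per_byte

# Q4: both agree (+1 rounds up to the next even number).
assert zp_stride_old(5, 4) == zp_stride_fixed(5, 4) == 6
# Q2 with n_blocks_per_col = 6 (K=384, block_size=64): stride must be 8, not 7.
assert zp_stride_old(6, 2) == 7
assert zp_stride_fixed(6, 2) == 8
```

For Q4 the two formulas are indistinguishable, which is why the bug could hide there for so long.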
Bug 3: Shift Formula Crosses Byte Boundaries
Now the native EP worked, but the JSEP path (the browser-facing JavaScript shaders in `matmulnbits.ts`) still produced garbage.
For Q4, each u32 word holds 8 values — processed in a single pass. For Q2, each word holds 16 values, requiring 2 passes of 8. The original shift used pass * 8, meaning pass 1 shifted by 8 bits — crossing from one byte into the next, mixing values from different bytes.
Fix: scale the shift to half a byte per pass, so `lowerShift` takes the values 0 and 4 for Q2 — each pass stays within a single byte's boundaries instead of spilling into the next.
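A small Python trace (simulating the packed layout, not the shader itself) makes the byte-crossing concrete. One byte holds values v0..v3; pass 1 should read v2, v3 from the high half of byte 0, but a shift of 8 lands in byte 1:

```python
# Eight 2-bit values packed LSB-first across two bytes of a word.
vals = [0, 1, 2, 3, 3, 2, 1, 0]
word = sum((v & 0x3) << (2 * i) for i, v in enumerate(vals))

def extract_pair(word, shift):
    """Read two consecutive 2-bit values starting at `shift`."""
    return [(word >> shift) & 0x3, (word >> (shift + 2)) & 0x3]

# Correct pass-1 shift (4) stays in byte 0 and yields v2, v3:
assert extract_pair(word, 4) == [2, 3]
# Buggy pass-1 shift (8) crosses into byte 1 and yields v4, v5 instead:
assert extract_pair(word, 8) == [3, 2]
```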
Bug 4: Value Extraction Ordering — The Nibble-Spread
After the shift fix, output changed but was still wrong. Deeper analysis revealed a fundamental ordering problem.
The Q4 extraction pattern unpack4xU8(b_value & 0x0F0F0F0F) works because it extracts the same bit position from all 4 bytes simultaneously — and for Q4, that gives 4 sequential values (one per byte). But for Q2, the same technique extracts bit position 0-1 from bytes 0, 1, 2, and 3 — producing values v0, v4, v8, v12 instead of v0, v1, v2, v3. The A-data is sequential, so a[2] * b[8] is computed instead of a[2] * b[2].
Fix: A "nibble-spread" technique that reorganizes bytes before extraction. Each pass takes 2 bytes (8 sequential values), spreads each nibble (4 bits = two Q2 values) into its own byte of a synthetic u32, then applies the standard unpack4xU8 + mask pattern:
```wgsl
let half_word = b_value >> (pass * 16u);
let byte_lo = half_word & 0xFFu;
let byte_hi = (half_word >> 8u) & 0xFFu;
let spread_word = (byte_lo & 0xFu)
                | ((byte_lo >> 4u) << 8u)
                | ((byte_hi & 0xFu) << 16u)
                | ((byte_hi >> 4u) << 24u);
b_value_lower = unpack4xU8(spread_word & 0x03030303u);
b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);
```
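The extraction can be verified outside the shader. The Python sketch below mirrors the WGSL logic and assumes the lower/upper vectors are interleaved downstream, as in the Q4 pattern (lower[i] is value 2i, upper[i] is value 2i+1); that interleave is our reading of the pattern, not a line from the shader:

```python
# Simulate the nibble-spread extraction and check that interleaving
# lower/upper recovers the sequential values v0..v7 for each pass.

def unpack4xU8(word):
    """WGSL unpack4xU8: little-endian byte split of a u32."""
    return [(word >> (8 * i)) & 0xFF for i in range(4)]

def extract_pass(b_value, p):
    half_word = (b_value >> (p * 16)) & 0xFFFF
    byte_lo = half_word & 0xFF
    byte_hi = (half_word >> 8) & 0xFF
    # Spread each nibble (two Q2 values) into its own byte.
    spread_word = ((byte_lo & 0xF)
                   | ((byte_lo >> 4) << 8)
                   | ((byte_hi & 0xF) << 16)
                   | ((byte_hi >> 4) << 24))
    lower = unpack4xU8(spread_word & 0x03030303)
    upper = unpack4xU8((spread_word >> 2) & 0x03030303)
    # Interleave as the Q4 pattern does: lower[i] = v(2i), upper[i] = v(2i+1).
    return [v for pair in zip(lower, upper) for v in pair]

vals = [i % 4 for i in range(16)]                 # 16 two-bit values in one u32
b_value = sum(v << (2 * i) for i, v in enumerate(vals))
assert extract_pass(b_value, 0) == vals[:8]
assert extract_pass(b_value, 1) == vals[8:]
```

Without the spread step, the same mask-and-unpack would hand back v0, v4, v8, v12 — the exact mis-ordering described above.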
This was applied to both the general shader path and the BlockSize32 optimized path.
Bug 5: A-Data Double-Advancement
After the nibble-spread fix, the result changed again — closer, but still incorrect. A Python trace script finally pinpointed the last bug: the A-data offset for pass 1 was wrong.
In the multi-pass loop, pass 0 reads A values via a loop that increments input_offset 8 times. Pass 1 then computed its starting offset as input_offset + 8/aComponents — but input_offset had already been advanced by pass 0's loop. This double-counted the offset, causing pass 1 to read A[16] instead of A[8], skipping 8 activation values entirely.
Fix: Pass 1 simply uses `input_offset` directly — it already points to exactly where pass 0 left off:

```wgsl
// Before (bug): input_offset + ${(pass * 8) / aComponents}
// After (fix):  input_offset
```
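A minimal Python trace of the offset bookkeeping (with a hypothetical `aComponents = 1`; variable names are illustrative) shows the double-count:

```python
# Pass 0's inner loop advances input_offset once per A value it reads.
a_components = 1
input_offset = 0
for _ in range(8):                  # pass 0 reads A[0..7]
    input_offset += 1

# Bug: adding pass * 8 on top of the already-advanced offset.
buggy_pass1_start = input_offset + (1 * 8) // a_components   # lands at A[16]
# Fix: the offset already points at the next unread value.
fixed_pass1_start = input_offset                              # lands at A[8]

assert (buggy_pass1_start, fixed_pass1_start) == (16, 8)
```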
After this fix, the 2-bit quantized model produced correct results on WebGPU, matching CPU output.
Parameterizing the Shader for Variable Bit Widths
Beyond the bug fixes, the JSEP shader needed systematic parameterization. Hard-coded Q4 assumptions were replaced with attributes.bits-driven constants throughout:
| Concept | Q4 | Q2 |
|---|---|---|
| Values per u32 word | 8 | 16 |
| Passes per word | 1 | 2 |
| Bit mask | `0x0F0F0F0Fu` | `0x03030303u` |
| Default zero point | 8 | 2 |
| ZP values per byte | 2 | 4 |
| ZP byte mask | `0xFu` | `0x3u` |
| `word_offset` increment | `8/aComponents` | `16/aComponents` |
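Every row of that table is a function of the bit width, which is what makes the parameterization tractable. A hedged Python sketch (the names mirror the table, not the actual shader variables):

```python
# Derive the per-bit-width shader constants from nbits alone.

def q_params(nbits):
    values_per_word = 32 // nbits
    byte_mask = (1 << nbits) - 1                    # 0xF for Q4, 0x3 for Q2
    return {
        "values_per_word": values_per_word,
        "passes_per_word": values_per_word // 8,    # 8 values handled per pass
        "bit_mask": sum(byte_mask << (8 * i) for i in range(4)),
        "default_zero_point": 1 << (nbits - 1),     # midpoint of the range
        "zp_values_per_byte": 8 // nbits,
        "zp_byte_mask": byte_mask,
    }

assert q_params(4)["bit_mask"] == 0x0F0F0F0F
assert q_params(2) == {
    "values_per_word": 16, "passes_per_word": 2, "bit_mask": 0x03030303,
    "default_zero_point": 2, "zp_values_per_byte": 4, "zp_byte_mask": 0x3,
}
```

The same derivation would hand Q8 its constants for free; Q3 would need extra care since 3 does not divide 8.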
Test Coverage
We added a `MatMul2BitsWebGpu` test suite to exercise the Q2 path on the WebGPU EP:

- Symmetric & asymmetric (with/without zero points)
- Multiple block sizes (16, 32, 64, 128) — block_size=64 is the critical case where `n_blocks_per_col` is not a multiple of 4, exercising the zero-point padding logic
- Varying dimensions (K=16 to 1024, N=1 to 384) — covering single-word and multi-word extraction patterns
- Batch tests (M=1, 4, 100)
All 9 test configurations pass on WebGPU EP, with results matching CPU baseline within tolerance.
Files Changed
| File | Change |
|---|---|
| matmul_nbits.cc | Remove Q2+ZP block, fix zero_blocks_per_col, guard DP4A |
| matmul_nbits_zero_pt.wgsl.template | Add bit_mask = 0x3u for Q2 |
| matmul_nbits.h | Allow nbits == 2 in JSEP kernel |
| matmulnbits.ts | Parameterize for Q2, shift fix, nibble-spread, A-offset fix |
| matmul_2bits_test.cc | WebGPU-specific Q2 test suite |
Takeaways
One feature, five bugs — each fix revealed the next layer of incorrectness. Without tests that compared against a CPU baseline, any single fix would have appeared to "do something" while still being wrong.
Bit-packing extraction is subtle — the Q4 pattern of "mask the same bits from all 4 bytes" only works because Q4 has exactly one value per nibble per byte. Q2 breaks that assumption fundamentally.
Trace scripts are essential — Python scripts that simulate shader logic step-by-step (nibble-spread verification, A-offset tracking) were what ultimately identified bugs 4 and 5 after code-reading alone proved insufficient.
Parameterize, don't fork — rather than creating a separate Q2 shader, making the existing shader bit-width-aware keeps the code maintainable and makes future N-bit support (Q3, Q8) straightforward.