Hector Li

Bringing 2-Bit Quantization to ONNX Runtime's WebGPU Backend

A story of five bugs, bit-level debugging, and running transformer models at 2-bit precision in the browser. Here's the PR.


Background

ONNX Runtime's MatMulNBits operator supports low-bit quantized matrix multiplication — packing weight values into 2, 4, or 8 bits per element. The WebGPU execution provider (both the native C++ path and the JavaScript/JSEP path) already supported 4-bit (Q4) quantization, but 2-bit (Q2) was blocked or broken. Our goal: make Q2 with zero points work correctly end-to-end so that 2-bit quantized transformer models run accurately in the browser via WebGPU.
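
To make the packing concrete, here is a minimal Python sketch of how Q2 fits four weight values into each byte. This is illustrative only, not ORT's actual code, and it assumes least-significant-bits-first packing:

def pack_q2(values):
    # Pack four 2-bit values per byte; value j of each group sits at bits 2j..2j+1.
    out = bytearray()
    for i in range(0, len(values), 4):
        b = 0
        for j, v in enumerate(values[i:i + 4]):
            b |= (v & 0x3) << (2 * j)
        out.append(b)
    return bytes(out)

def unpack_q2(data, count):
    # Invert the packing: byte i//4, shift by 2*(i%4), mask to 2 bits.
    return [(data[i // 4] >> (2 * (i % 4))) & 0x3 for i in range(count)]

vals = [0, 1, 2, 3, 3, 2, 1, 0]
assert unpack_q2(pack_q2(vals), len(vals)) == vals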

What seemed like a single feature gap turned out to be five distinct bugs, each hidden behind the last.


Bug 1: The Gate — Hard-Coded Rejection of Q2 + Zero Points

The first issue was immediate: attempting to run a 2-bit model with zero points threw a runtime error:

"Currently, zero points are not supported for Q2 quantization."

Two enforcement guards explicitly blocked Q2:

  • Native WebGPU EP (matmul_nbits.cc): An ORT_ENFORCE(nbits != 2) when zero points were present.
  • JSEP C++ kernel (matmul_nbits.h): ORT_ENFORCE(nbits_ == 4) — only Q4 was allowed at all.

Additionally, the WGSL zero-point extraction template (matmul_nbits_zero_pt.wgsl.template) had #elif n_bits == 2 but was missing the bit_mask constant, so even if the guard were removed, the shader would malfunction.

Fix: Remove the enforcement blocks, add const bit_mask = 0x3u; for Q2, and guard the DP4A path (which uses a hardcoded LUT assuming zero_point = 2) so it skips Q2 when custom zero points are present.
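
As a scalar sketch of what the WGSL template's extraction boils down to (hypothetical Python, assuming LSB-first packing; the real logic lives in matmul_nbits_zero_pt.wgsl.template):

def extract_zero_point(zp_bytes, index, nbits=2):
    per_byte = 8 // nbits           # 4 zero points per byte for Q2, 2 for Q4
    bit_mask = (1 << nbits) - 1     # 0x3 for Q2: the constant the template was missing
    shift = (index % per_byte) * nbits
    return (zp_bytes[index // per_byte] >> shift) & bit_mask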


Bug 2: Zero Point Buffer Stride Miscalculation

With the gates removed, tests ran — but produced wrong results. The root cause was in how zero_blocks_per_col was computed.

Zero points are packed into bytes: for Q4, two values per byte; for Q2, four values per byte. Each column's zero points are byte-aligned, so the shader uses a flat linear stride to skip between columns. The original formula:

uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0
    ? n_blocks_per_col : n_blocks_per_col + 1;

This "+1" was a Q4 shortcut. For Q2 with n_blocks_per_col = 6 (e.g., K=384, block_size=64), the stride needs to round up to the next multiple of 4 (values per byte), not just add 1.

Fix: Proper ceiling-to-multiple formula:

const uint32_t zp_elements_per_byte = 8 / nbits;
uint32_t zero_blocks_per_col =
    (n_blocks_per_col + zp_elements_per_byte - 1)
    / zp_elements_per_byte * zp_elements_per_byte;
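
A quick Python check of both formulas on the failing shape from this section, plus the Q4 case that kept the bug latent:

def old_stride(n_blocks, nbits):
    # Original Q4 shortcut: round up by at most one block.
    return n_blocks if n_blocks % (8 // nbits) == 0 else n_blocks + 1

def new_stride(n_blocks, nbits):
    # Ceiling to the next multiple of (8 / nbits) values per byte.
    per_byte = 8 // nbits
    return (n_blocks + per_byte - 1) // per_byte * per_byte

print(old_stride(6, 2))   # 7: not a multiple of 4, so later columns' zero points land mid-byte
print(new_stride(6, 2))   # 8: correctly padded to a whole number of bytes
print(old_stride(5, 4), new_stride(5, 4))   # 6 6: for Q4 the shortcut coincides with the fix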

Bug 3: Shift Formula Crosses Byte Boundaries

Now the native EP worked, but the JSEP path (the browser-facing JavaScript shaders in matmulnbits.ts) still produced garbage.

For Q4, each u32 word holds 8 values — processed in a single pass. For Q2, each word holds 16 values, requiring 2 passes of 8. The original shift used pass * 8, meaning pass 1 shifted by 8 bits — crossing from one byte into the next, mixing values from different bytes.

Fix: lowerShift = pass * bits * 2. For Q2 this gives shifts of 0 and 4, keeping each pass within a single byte's boundaries.
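
To see why, map each unpack4xU8 lane back to the source byte it reads after the shift. A small Python trace, in the spirit of the scripts used later:

bits = 2
for p in (0, 1):
    buggy, fixed = p * 8, p * bits * 2
    for name, s in (("buggy", buggy), ("fixed", fixed)):
        # Lane i of the masked word takes bits [8*i + s, 8*i + s + 1].
        src_bytes = [(8 * i + s) // 8 for i in range(4)]
        print(f"pass {p} {name}: lane bits come from bytes {src_bytes}")
# pass 1 buggy: bytes [1, 2, 3, 4] -- lane 3 even reads past the word (zeros shift in)
# pass 1 fixed: bytes [0, 1, 2, 3] -- each lane stays in its own byte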


Bug 4: Value Extraction Ordering — The Nibble-Spread

After the shift fix, output changed but was still wrong. Deeper analysis revealed a fundamental ordering problem.

The Q4 extraction pattern unpack4xU8(b_value & 0x0F0F0F0F) works because it extracts the same bit position from all 4 bytes simultaneously — and for Q4, that gives 4 sequential values (one per byte). But for Q2, the same technique extracts bit position 0-1 from bytes 0, 1, 2, and 3 — producing values v0, v4, v8, v12 instead of v0, v1, v2, v3. The A-data is sequential, so a[2] * b[8] is computed instead of a[2] * b[2].

Fix: A "nibble-spread" technique that reorganizes bytes before extraction. Each pass takes 2 bytes (8 sequential values), spreads each nibble (4 bits = two Q2 values) into its own byte of a synthetic u32, then applies the standard unpack4xU8 + mask pattern:

let half_word = b_value >> (pass * 16u);
let byte_lo = half_word & 0xFFu;
let byte_hi = (half_word >> 8u) & 0xFFu;
let spread_word = (byte_lo & 0xFu)
    | ((byte_lo >> 4u) << 8u)
    | ((byte_hi & 0xFu) << 16u)
    | ((byte_hi >> 4u) << 24u);
b_value_lower = unpack4xU8(spread_word & 0x03030303u);
b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);

This was applied to both the general shader path and the BlockSize32 optimized path.
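
A Python re-implementation of the spread (in the same spirit as the trace scripts from Bug 5, assuming LSB-first packing) confirms each pass now yields eight sequential values:

def spread_pass(b_value, p):
    # Take the pass's 2 bytes and spread each nibble into its own byte.
    half = (b_value >> (p * 16)) & 0xFFFF
    lo, hi = half & 0xFF, (half >> 8) & 0xFF
    spread = (lo & 0xF) | ((lo >> 4) << 8) | ((hi & 0xF) << 16) | ((hi >> 4) << 24)
    lower = [(spread >> (8 * i)) & 0x3 for i in range(4)]        # unpack4xU8(spread & 0x03030303)
    upper = [(spread >> (8 * i + 2)) & 0x3 for i in range(4)]    # unpack4xU8((spread >> 2) & 0x03030303)
    return [v for pair in zip(lower, upper) for v in pair]

word = 0xE4E4E4E4   # every byte packs the values 0, 1, 2, 3 (LSB first)
assert spread_pass(word, 0) == [0, 1, 2, 3, 0, 1, 2, 3]   # v0..v7, in order
assert spread_pass(word, 1) == [0, 1, 2, 3, 0, 1, 2, 3]   # v8..v15, in order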


Bug 5: A-Data Double-Advancement

After the nibble-spread fix, the result changed again — closer, but still incorrect. A Python trace script finally pinpointed the last bug: the A-data offset for pass 1 was wrong.

In the multi-pass loop, pass 0 reads A values via a loop that increments input_offset 8 times. Pass 1 then computed its starting offset as input_offset + 8/aComponents — but input_offset had already been advanced by pass 0's loop. This double-counted the offset, causing pass 1 to read A[16] instead of A[8], skipping 8 activation values entirely.

Fix: Pass 1 simply uses input_offset directly — it already points to exactly where pass 0 left off:

// Before (bug): input_offset + ${(pass * 8) / aComponents}
// After (fix):  input_offset
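
A stripped-down Python trace of the offset arithmetic (taking aComponents = 1 for simplicity) reproduces the symptom:

input_offset = 0
for p in range(2):
    buggy_start = input_offset + p * 8   # re-adds what pass 0 already consumed
    fixed_start = input_offset           # already points past pass 0's reads
    print(f"pass {p}: buggy A start {buggy_start}, fixed A start {fixed_start}")
    for _ in range(8):                   # the inner loop advances the offset
        input_offset += 1
# pass 0: both start at A[0]
# pass 1: buggy starts at A[16] (skipping A[8..15]); fixed starts at A[8]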

After this fix, the 2-bit quantized model produced correct results on WebGPU, matching CPU output.


Parameterizing the Shader for Variable Bit Widths

Beyond the bug fixes, the JSEP shader needed systematic parameterization. Hard-coded Q4 assumptions were replaced with attributes.bits-driven constants throughout:

Concept                  Q4               Q2
Values per u32 word      8                16
Passes per word          1                2
Bit mask                 0x0F0F0F0Fu      0x03030303u
Default zero point       8                2
ZP values per byte       2                4
ZP byte mask             0xFu             0x3u
word_offset increment    8/aComponents    16/aComponents
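
Every row of this table falls out of nbits. A small Python sketch of the relationship (illustrative, not the actual TypeScript; the aComponents default of 4 is an assumption):

def nbit_params(nbits, a_components=4):
    values_per_word = 32 // nbits
    byte_mask = (1 << nbits) - 1
    return {
        "values_per_u32": values_per_word,
        "passes_per_word": values_per_word // 8,
        "bit_mask": int.from_bytes(bytes([byte_mask] * 4), "little"),
        "default_zero_point": 1 << (nbits - 1),
        "zp_values_per_byte": 8 // nbits,
        "zp_byte_mask": byte_mask,
        "word_offset_increment": values_per_word // a_components,
    }

assert nbit_params(4)["bit_mask"] == 0x0F0F0F0F   # the Q4 column
assert nbit_params(2)["bit_mask"] == 0x03030303   # the Q2 column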

Test Coverage

We added a MatMul2BitsWebGpu test suite to exercise the Q2 path on the WebGPU EP:

  • Symmetric & asymmetric (with/without zero points)
  • Multiple block sizes (16, 32, 64, 128) — block_size=64 is the critical case where n_blocks_per_col is not a multiple of 4, exercising the zero-point padding logic
  • Varying dimensions (K=16 to 1024, N=1 to 384) — covering single-word and multi-word extraction patterns
  • Batch tests (M=1, 4, 100)

All 9 test configurations pass on the WebGPU EP, with results matching the CPU baseline within tolerance.


Files Changed

File                                  Change
matmul_nbits.cc                       Remove Q2+ZP block, fix zero_blocks_per_col, guard DP4A
matmul_nbits_zero_pt.wgsl.template    Add bit_mask = 0x3u for Q2
matmul_nbits.h                        Allow nbits == 2 in JSEP kernel
matmulnbits.ts                        Parameterize for Q2, shift fix, nibble-spread, A-offset fix
matmul_2bits_test.cc                  WebGPU-specific Q2 test suite

Takeaways

  1. One feature, five bugs — each fix revealed the next layer of incorrectness. Without tests that compared against a CPU baseline, any single fix would have appeared to "do something" while still being wrong.

  2. Bit-packing extraction is subtle — the Q4 pattern of "mask the same bits from all 4 bytes" only works because Q4 has exactly one value per nibble per byte. Q2 breaks that assumption fundamentally.

  3. Trace scripts are essential — Python scripts that simulate shader logic step-by-step (nibble-spread verification, A-offset tracking) were what ultimately identified bugs 4 and 5 after code-reading alone proved insufficient.

  4. Parameterize, don't fork — rather than creating a separate Q2 shader, making the existing shader bit-width-aware keeps the code maintainable and makes future N-bit support (Q3, Q8) straightforward.
