fix: Include all Dense projection layers in ONNX export (output dim 48)

#3

Problem

The current ONNX exports (model.onnx and model_int8.onnx) are missing the final Dense projection layers from PyLate's modules_list pipeline. This causes the ONNX models to output the intermediate hidden dimension instead of the correct final embedding dimension.

mxbai-edge-colbert-v0-17m:

  • Current ONNX output: 512 dimensions (intermediate)
  • Expected output: 48 dimensions (after all Dense projections)

Root Cause

The ONNX export was done using the standard HuggingFace Transformers export pipeline, which only picks up the transformer backbone and possibly the first linear head. It does not include the additional Dense modules stored in the separate 1_Dense/, 2_Dense/ directories that are part of PyLate's modules_list architecture.

Fix

Re-exported by wrapping the full pipeline (Transformer + all Dense layers) into a single module before ONNX export:

  • model.onnx: fp32, opset 17, output dim = 48
  • model_int8.onnx: int8 dynamic quantization with projection layers kept in fp32 (they are small and precision-sensitive)
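The wrapping step can be sketched roughly like this. This is an illustrative sketch, not the actual export script: the class name, and the assumption that PyLate's modules_list exposes the backbone and the Dense heads as plain PyTorch modules, are mine.

```python
import torch

class FullPipelineWrapper(torch.nn.Module):
    """Hypothetical wrapper chaining the transformer backbone with every
    Dense projection, so a single ONNX graph covers the whole pipeline."""

    def __init__(self, transformer, dense_layers):
        super().__init__()
        self.transformer = transformer
        self.dense_layers = torch.nn.ModuleList(dense_layers)

    def forward(self, input_ids, attention_mask):
        hidden = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        for dense in self.dense_layers:  # e.g. 256 -> 512 -> 48
            hidden = dense(hidden)
        return hidden  # (batch, seq_len, 48)

# Export sketch (opset 17, dynamic batch/sequence axes):
# torch.onnx.export(
#     FullPipelineWrapper(backbone, [dense_1, dense_2]),
#     (dummy_input_ids, dummy_attention_mask),
#     "model.onnx",
#     input_names=["input_ids", "attention_mask"],
#     output_names=["last_hidden_state"],
#     dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
#                   "attention_mask": {0: "batch", 1: "seq"}},
#     opset_version=17,
# )
```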

Both files have been verified for:

  • Correct output dimensions
  • ONNX model validity (onnx.checker)
  • Numerical consistency with PyTorch (fp32)

Benchmark Results: Fixed ONNX vs Original PyLate Checkpoint

Tested on 6 NanoBEIR datasets using brute-force MaxSim scoring. The fixed ONNX fp32 model produces identical results to the original PyLate checkpoint (embedding max diff ~2-3e-07).

nDCG@10

| Dataset  | PyLate (original) | ONNX fp32 | ONNX int8 | fp32 diff | int8 diff |
|----------|-------------------|-----------|-----------|-----------|-----------|
| FiQA     | 0.5261            | 0.5261    | 0.4922    | +0.0000   | -0.0339   |
| SciFact  | 0.7943            | 0.7943    | 0.7893    | +0.0000   | -0.0051   |
| NFCorpus | 0.3683            | 0.3683    | 0.3625    | +0.0000   | -0.0058   |
| SCIDOCS  | 0.3898            | 0.3898    | 0.3507    | +0.0000   | -0.0392   |
| HotpotQA | 0.8918            | 0.8918    | 0.8609    | +0.0000   | -0.0309   |
| NQ       | 0.7092            | 0.7092    | 0.6460    | +0.0000   | -0.0632   |

Key findings:

  • fp32 ONNX matches PyLate exactly on all 6 datasets (nDCG@10 diff = 0.0000)
  • int8 ONNX shows expected small degradation from dynamic quantization (projection layers kept in fp32)
  • Embedding max diff between fp32 ONNX and PyLate: ~2-3e-07 (numerically identical)
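For reference, the brute-force MaxSim scoring used in the benchmark can be sketched as follows (a minimal sketch assuming L2-normalized token embeddings, so the dot product equals cosine similarity):

```python
import numpy as np

def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """Late-interaction MaxSim: for each query token, take its best
    similarity over all document tokens, then sum over query tokens.

    query_emb: (num_query_tokens, dim), doc_emb: (num_doc_tokens, dim),
    both with L2-normalized rows (dim = 48 for this model).
    """
    sim = query_emb @ doc_emb.T            # (num_query_tokens, num_doc_tokens)
    return float(sim.max(axis=1).sum())    # best match per query token, summed
```

Brute force here means scoring every document this way and sorting, with no approximate index in between, so any ranking difference is attributable to the embeddings alone.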

Hello!

Sentence Transformers maintainer here, I don't work at mixedbread, but I wanted to clarify why this is the case currently. This PyLate model consists of 3 modules: Transformer, Dense and Dense. Like in Sentence Transformers, there is currently only functionality to convert the Transformer to ONNX. When loading the model again, you can specify backend="onnx" and the Transformer will run with ONNX instead, without changes to the Dense modules.

In the future, I would like to change it so the entire model can be exported, instead of just the "core" of the model. At the time, this was quite tricky with optimum, and I wanted to be compatible with the already existing ONNX exports which also only converted the transformers model and not the pooling/normalization (in the case of dense embedding models).

It's challenging to change this without breaking backwards compatibility, though.

This PR will also break backwards compatibility: model = PyLate("...", backend="onnx") will most likely no longer work. I haven't tested it yet; you can do so by also adding the revision="refs/pr/3" initialization option.

  • Tom Aarsen

Hi Tom, thanks for the detailed explanation!

I tested backend="onnx" thoroughly (sentence-transformers 5.2.3 + PyLate 1.3.4, both latest):

| Test | Result |
|------|--------|
| PyLate `ColBERT(..., backend="onnx")` | `TypeError`, param not supported |
| `SentenceTransformer(..., backend="onnx")` | Can't load `pylate.models.Dense.Dense` modules, falls back to plain transformer (256-dim) |
| `Transformer(..., backend="onnx")` on snapshot | `KeyError: 'last_hidden_state'` because the ONNX output is named `"output"` |

The current model.onnx was added by @NohTow (d3322b2) for next-plaid-onnx, but it only includes 1_Dense (256 → 512) and is missing 2_Dense (512 → 48), so it outputs 512-dim embeddings instead of the 48-dim specified in onnx_config.json. As a result, the current ONNX doesn't work properly with either sentence-transformers or next-plaid: next-plaid will silently store token embeddings at 512 dimensions instead of 48, increasing storage by ~10x and degrading retrieval quality.

The root cause is that pylate-onnx-export v0.1.0 on PyPI only exports a single Dense layer. This was fixed in v1.0.7 on GitHub. I re-exported using v1.0.7 and confirmed it produces the correct 48-dim output with all Dense layers included.

I've updated this PR with the re-exported files from pylate-onnx-export v1.0.7. Happy to go whichever direction works best here - we could merge as-is, or if you'd prefer to keep model.onnx reserved for sentence-transformers, we can rename the full-pipeline ONNX to something like model_colgrep.onnx / model_colgrep_int8.onnx instead. My goal is simply to get a working ONNX for next-plaid, so any path that gets us there works for me!

Ah yes, my bad: I checked that the models were working but did not bench them or catch the dim mismatch. I should have updated the export tooling (although I thought I had just installed it on my node, anyways).
I am not an ONNX expert, so I would have to check, but hopefully @raphaelsty made it so we can load either in ColGrep or PyLate (when we fix the feature).
Maybe it's better if we do the checks before merging again, so we don't end up making lots of iterative PRs on the various PyLate model repos?

Thanks a lot for finding this!!

It's a shame this has been stalled for so long.
I guess if we do not find time to dig into this, the best option is to merge the PRs as-is and fix PyLate ONNX support later.

bclavie changed pull request status to merged
