CodeRankEmbed-flash-attn

A bf16 quantization of nomic-ai/CodeRankEmbed with flash-attention built into a custom modeling_hf_nomic_bert.py shipped in this repo. It is not a finetune: the weights are the original CodeRankEmbed weights cast to bf16 (no further training), and attention runs via flash_attn varlen instead of the original eager O(seq²) path.

Why

nomic-ai/CodeRankEmbed loads through trust_remote_code, and its attention path is eager only: activation memory grows as batch × heads × seq², which OOMs at large batches even though the model is only 137M params. flash_attn's varlen path computes the same attention in O(N) memory (N = total tokens in the batch) by packing unpadded sequences, so the large batches that OOM the eager path run comfortably, with embeddings that match the fp32 original to within bf16 precision (cosine 0.9986; see Parity & performance). This repo ships that flash path in the modeling file itself, so no runtime patching or post-load hooks are needed.
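To make the batch × heads × seq² growth concrete, here is a rough back-of-the-envelope sketch. It counts only one layer's attention-score tensor, so treat it as a lower bound on what the eager path allocates. The head count of 12 is an assumption about the architecture, not something stated in this README, and the function name is illustrative.

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int,
                           bytes_per_elem: int = 4) -> int:
    """Bytes for one [batch, heads, seq_len, seq_len] attention-score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


ASSUMED_HEADS = 12  # assumption: not stated in this README

# fp32 (4 bytes per element), batch size 64, padded to various lengths.
for seq_len in (512, 2048, 8192):
    gib = attention_scores_bytes(64, ASSUMED_HEADS, seq_len) / 2**30
    print(f"seq_len={seq_len:>5}: {gib:,.2f} GiB per layer")
```

Doubling the padded sequence length quadruples this figure, which is why a small model can still run out of memory on long inputs.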

Behavior

  • Loads bf16 by default. flash_attn requires half precision and the model runs bf16 in any real serving setup, so the weights are stored bf16 and config.json declares torch_dtype: bfloat16. The upstream custom from_pretrained silently dropped torch_dtype and always loaded fp32; the copy in this repo honors it, so the model loads bf16 natively, like any normal HF model. Pass torch_dtype=torch.float32 to load fp32 (note: the stored weights are bf16-precision, so this only widens the dtype, not the precision).
  • CUDA + flash_attn → flash-varlen path (fast, low VRAM). Attention tensors are cast to bf16 internally as a safety net; with the default bf16 load this is a no-op.
  • CPU, or no flash_attn → the original eager attention algorithm runs unchanged. Because the model is bf16, eager runs in bf16 here: numerically equivalent to the flash path, just without its memory and throughput wins. The model loads and encodes on any host; the forward selects the path automatically (_FLASH_AVAILABLE and hidden_states.is_cuda).

Usage

Identical to the original. The query prompt must include the task-instruction prefix "Represent this query for searching relevant code: "; documents need no prefix.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("handwoven8588/CodeRankEmbed-flash-attn", trust_remote_code=True)
queries = ["Represent this query for searching relevant code: Calculate the n-th factorial"]
codes   = ["def fact(n):\n    if n < 0:\n        raise ValueError\n    return 1 if n == 0 else n * fact(n - 1)"]

q = model.encode(queries, normalize_embeddings=True)
d = model.encode(codes,   normalize_embeddings=True)
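
Because forgetting the query prefix is an easy mistake, a small wrapper can add it for you. The helper name below is hypothetical; only the prefix string itself comes from this README.

```python
QUERY_PREFIX = "Represent this query for searching relevant code: "


def with_query_prefix(query: str) -> str:
    """Prepend the task-instruction prefix unless the query already has it."""
    if query.startswith(QUERY_PREFIX):
        return query
    return QUERY_PREFIX + query


raw_queries = ["Calculate the n-th factorial"]
queries = [with_query_prefix(q) for q in raw_queries]
print(queries[0])
```

The resulting list can be passed to model.encode exactly as in the usage example; documents are passed as they are.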

Parity & performance

The weights are the original CodeRankEmbed weights (bf16-cast), so embeddings match the fp32 original to within bf16 precision. Measured on an RTX 3090 Ti, flash_attn 2.8.3, CLS pooling, L2-normalized, batch size 64:

| metric | nomic-ai/CodeRankEmbed (fp32, eager) | this repo (bf16, flash-varlen) |
|---|---|---|
| cosine vs fp32 reference | 1.000000 | 0.9986 |
| peak VRAM (bs=64) | 6.7 GB | 2.1 GB (≈3.3× less) |
| throughput (bs=64) | 52,000 tok/s | 162,000 tok/s (≈3.1× faster) |

Both columns use the same 512-snippet corpus (~236k tokens, 20–900 tokens each) at batch size 64. The flash-varlen path also scales to far larger batches (bs=256 fits in ~7.6 GB) where eager O(seq²) blows up. Parity (>0.997) reproduced on both an RTX 4060 and the 3090 Ti.
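
A parity check of this kind can be reproduced with a few lines of Python. The sketch below is a minimal, dependency-free version: it compares two sets of embeddings row by row and reports the mean and minimum cosine. This README does not say how the per-snippet cosines were aggregated, so reporting both is an assumption, and the function names are illustrative.

```python
import math
from typing import Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def parity_stats(reference: Sequence[Sequence[float]],
                 candidate: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Return (mean, min) of row-wise cosine between two embedding sets."""
    scores = [cosine(r, c) for r, c in zip(reference, candidate)]
    return sum(scores) / len(scores), min(scores)


# Toy vectors standing in for fp32-reference and bf16-flash embeddings.
ref = [[1.0, 0.0], [0.0, 1.0]]
cand = [[1.0, 0.0], [1.0, 1.0]]
mean_cos, min_cos = parity_stats(ref, cand)
print(f"mean={mean_cos:.4f} min={min_cos:.4f}")
```

In practice `reference` would hold embeddings from the fp32 original and `candidate` the embeddings from this repo for the same inputs; the arrays returned by model.encode can be passed in directly, since they iterate row by row.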

What changed vs the source repo

  1. Weights: fp32 → bf16. flash_attn only accepts half precision and the model runs bf16 in any real serving configuration, so the weights are stored bf16 and (via the load fix below) arrive bf16. That is simply how this model is used, and it removes the need for a post-load dtype cast. Parity-neutral (cosine 0.9986 vs the fp32 original); the smaller download is incidental, not the reason.
  2. from_pretrained dtype fix: the upstream custom from_pretrained instantiated the model fp32 and load_state_dict-ed the checkpoint into fp32 params, ignoring torch_dtype. The copy here adds the standard transformers dtype resolution (explicit arg → config.torch_dtype → checkpoint dtype) so the model loads in its declared dtype.
  3. Flash-varlen forward: NomicBertAttention.forward gains an unpad → flash_attn_varlen_qkvpacked_func(causal=False) → repad branch (CUDA + flash_attn); the original eager block is kept as the fallback. NomicBertModel.forward skips get_extended_attention_mask on the flash path. Rotary embeddings are applied to the dense [B, S, 3, H, D] tensor before unpadding; this ordering is the correctness keystone.
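
The unpad and repad steps around the varlen call are mostly index bookkeeping. The sketch below shows that bookkeeping in plain Python on a 0/1 attention mask. It is not the code in the modeling file, and the function names are hypothetical. In a real forward pass these would be tensors, and the packed QKV, the cumulative sequence lengths, and the maximum length are what flash_attn_varlen_qkvpacked_func takes as input.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def unpad_bookkeeping(mask: Sequence[Sequence[int]]) -> Tuple[List[int], List[int], int]:
    """Return (flat indices of real tokens, cu_seqlens, max_seqlen) for a 0/1 mask."""
    seq_len = len(mask[0])
    indices = [b * seq_len + s
               for b, row in enumerate(mask)
               for s, keep in enumerate(row) if keep]
    seqlens = [sum(row) for row in mask]
    cu_seqlens = [0] + list(accumulate(seqlens))  # boundaries of each packed sequence
    return indices, cu_seqlens, max(seqlens)


def repad(packed: Sequence[float], indices: Sequence[int],
          batch: int, seq_len: int, fill: float = 0.0) -> List[List[float]]:
    """Scatter packed per-token values back into a dense [batch, seq_len] layout."""
    flat = [fill] * (batch * seq_len)
    for value, idx in zip(packed, indices):
        flat[idx] = value
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch)]


mask = [[1, 1, 1, 0],   # 3 real tokens, 1 pad
        [1, 1, 0, 0]]   # 2 real tokens, 2 pads
indices, cu_seqlens, max_seqlen = unpad_bookkeeping(mask)
dense = repad([10.0, 11.0, 12.0, 20.0, 21.0], indices, batch=2, seq_len=4)
print(indices, cu_seqlens, max_seqlen, dense)
```

The bookkeeping also suggests why rotary embeddings go on before unpadding: in the dense layout a token's column is its position within its own sequence, whereas after packing its flat index no longer carries that information.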

License & attribution

MIT, the same license as nomic-ai/CodeRankEmbed (see NOTICE). The weights, tokenizer, and the bulk of the modeling file are a verbatim derivative of nomic-ai/CodeRankEmbed; the modeling file derives from Tri Dao's BERT implementation, and CodeRankEmbed was trained by the CoRNStack team (Suresh et al., 2025). Cite their work:

@misc{suresh2025cornstackhighqualitycontrastivedata,
  title  = {CoRNStack: High-Quality Contrastive Data for Text and Code Retrieval},
  author = {Suresh, K N Q and Wang, Xiang and Khan, Saqib and others},
  year   = {2025},
}