Flux Attention: Qwen3-4B

Flux Attention is a context-aware framework that dynamically optimizes attention computation at the layer level. By integrating a lightweight Layer Router into frozen pretrained LLMs, the proposed method adaptively routes each layer to Full Attention (FA) or Sparse Attention (SA) based on the input context.
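To make the routing idea concrete, here is a toy, pure-Python sketch of a hybrid attention stack. It is not the Flux Attention implementation. Several pieces are illustrative assumptions:

- The router mean-pools the context, scores each layer with a linear probe and a sigmoid, then thresholds the score.
- All layers are routed once, up front.
- The sparse pattern keeps the first "sink" block plus the query's own block.

The names `route_layers`, `block_sparse_mask`, and `hybrid_forward` are hypothetical.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v, allowed=None):
    """Causal single-head attention over lists of vectors.

    `allowed(i, j)` restricts which earlier keys query i may attend to;
    None means full causal attention (the FA path).
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        keys = [j for j in range(i + 1) if allowed is None or allowed(i, j)]
        scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d) for j in keys]
        weights = softmax(scores)
        out.append([sum(w * v[j][c] for w, j in zip(weights, keys))
                    for c in range(len(v[0]))])
    return out

def block_sparse_mask(block_size, sink_blocks=1, local_blocks=1):
    """Assumed SA pattern: keep the first `sink_blocks` key blocks plus the
    `local_blocks` blocks ending at the query's own block."""
    def allowed(i, j):
        qb, kb = i // block_size, j // block_size
        return kb < sink_blocks or qb - kb < local_blocks
    return allowed

def route_layers(hidden, router_weights, threshold=0.5):
    """Toy router: mean-pool the context, score each layer with a linear
    probe and a sigmoid, and choose FA when the score exceeds `threshold`."""
    pooled = [sum(col) / len(hidden) for col in zip(*hidden)]
    modes = []
    for w, b in router_weights:
        z = sum(wi * xi for wi, xi in zip(w, pooled)) + b
        p = 1.0 / (1.0 + math.exp(-z))
        modes.append("FA" if p > threshold else "SA")
    return modes

def hybrid_forward(x, router_weights, block_size=4):
    """Route every layer once from the input, then run FA or SA per layer."""
    modes = route_layers(x, router_weights)
    for mode in modes:
        allowed = None if mode == "FA" else block_sparse_mask(block_size)
        # Toy layer: identity Q/K/V projections plus a residual connection
        attn = attention(x, x, x, allowed)
        x = [[a + b for a, b in zip(row, arow)] for row, arow in zip(x, attn)]
    return x, modes

random.seed(0)
x = [[random.gauss(0.0, 1.0) for _ in range(8)] for _ in range(16)]
# Zero probe weights: the bias alone decides, so layer 0 -> FA, layer 1 -> SA
router_weights = [([0.0] * 8, 5.0), ([0.0] * 8, -5.0)]
out, modes = hybrid_forward(x, router_weights)
print(modes)  # ['FA', 'SA']
```

Real block-sparse kernels skip masked blocks entirely, which is where the savings come from. This toy still visits every position and only illustrates which keys an SA layer may see compared with an FA layer.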

Sample Usage

To use this model, you need to install the Flux Attention library and its dependencies (including Block-Sparse-Attention).
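Before running the sample, you can confirm that the packages are importable with a `find_spec` check, which locates modules without importing them. The import names are assumptions. `fluxattn` matches the imports in the sample code. `block_sparse_attn` is a guess at the module name installed by Block-Sparse-Attention and may differ in your build.

```python
import importlib.util

# Import names to check; `block_sparse_attn` is an assumed name for the
# Block-Sparse-Attention module and may need adjusting for your install.
REQUIRED_MODULES = {
    "torch": "PyTorch",
    "transformers": "Hugging Face Transformers",
    "fluxattn": "Flux Attention",
    "block_sparse_attn": "Block-Sparse-Attention (assumed import name)",
}

def check_dependencies(modules=REQUIRED_MODULES):
    """Return {module_name: installed?} without importing the packages."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}

if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{'OK     ' if ok else 'MISSING'} {name} ({REQUIRED_MODULES[name]})")
```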

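import os
import json
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_sparse_model(model_path):
    """
    Dynamically loads the correct sparse architecture based on config.json.
    Accepts either a local directory or a Hugging Face Hub repo ID.
    """
    if os.path.isdir(model_path):
        config_path = os.path.join(model_path, "config.json")
    else:
        # Hub repo ID: download (and cache) config.json before inspecting it
        config_path = hf_hub_download(repo_id=model_path, filename="config.json")
    with open(config_path, "r") as f:
        config_data = json.load(f)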

    arch = config_data.get("architectures", [])
    if not arch:
        raise ValueError("No architecture found in config.json")

    arch_name = arch[0]
    print(f"🚀 Detected architecture: {arch_name}")

    # Register custom architectures
    if "PawLlama" in arch_name:
        from fluxattn.training.eval.modeling_flash_llama import (
            PawLlamaForCausalLM, PawLlamaConfig
        )
        AutoModelForCausalLM.register(PawLlamaConfig, PawLlamaForCausalLM)
        model_cls = PawLlamaForCausalLM
        
    elif "PawQwen" in arch_name:
        from fluxattn.training.eval.modeling_flash_qwen import (
            PawQwen3ForCausalLM, PawQwen3Config
        )
        AutoModelForCausalLM.register(PawQwen3Config, PawQwen3ForCausalLM)
        model_cls = PawQwen3ForCausalLM
    else:
        raise ValueError(f"Unsupported architecture: {arch_name}")

    # Load model
    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model

# --- Execution ---
model_id = "QQTang1223/Flux-Attention-Qwen3-4B" # Replace with local path if necessary
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print("Loading Flux Attention Model...")
model = load_sparse_model(model_id)
model.eval()

# Generate
input_text = "Explain quantum mechanics in one sentence."
# With device_map="auto", send inputs to the device that holds the model's first layers
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

print("Generating...")
outputs = model.generate(**inputs, max_new_tokens=100)
print("\nOutput:\n" + tokenizer.decode(outputs[0], skip_special_tokens=True))
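The `if`/`elif` dispatch in `load_sparse_model` can also be written as a lookup table, so that adding another architecture becomes a one-line change. This is a hypothetical refactor, not part of the Flux Attention library. The module paths and class names are the ones used in the sample code. `ARCH_REGISTRY`, `resolve_architecture`, and `load_model_class` are illustrative names.

```python
import importlib

# Hypothetical registry: substring of the architecture name ->
# (module path, model class name, config class name), as in the sample code
ARCH_REGISTRY = {
    "PawLlama": ("fluxattn.training.eval.modeling_flash_llama",
                 "PawLlamaForCausalLM", "PawLlamaConfig"),
    "PawQwen": ("fluxattn.training.eval.modeling_flash_qwen",
                "PawQwen3ForCausalLM", "PawQwen3Config"),
}

def resolve_architecture(config_data, registry=ARCH_REGISTRY):
    """Map the first entry of config.json's `architectures` to a registry target."""
    archs = config_data.get("architectures") or []
    if not archs:
        raise ValueError("No architecture found in config.json")
    arch_name = archs[0]
    for key, target in registry.items():
        if key in arch_name:
            return target
    raise ValueError(f"Unsupported architecture: {arch_name}")

def load_model_class(config_data):
    """Import the matching module lazily and return (model_cls, config_cls)."""
    module_name, model_name, config_name = resolve_architecture(config_data)
    module = importlib.import_module(module_name)
    return getattr(module, model_name), getattr(module, config_name)
```

Because the import happens only after a match is found, you can unit-test `resolve_architecture` without `fluxattn` installed.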

Citation

@misc{qiu2026fluxattentioncontextawarehybrid,
      title={Flux Attention: Context-Aware Hybrid Attention for Efficient LLMs Inference}, 
      author={Quantong Qiu and Zhiyi Hong and Yi Yang and Haitian Wang and Kebin Liu and Qingqing Dang and Juntao Li and Min Zhang},
      year={2026},
      eprint={2604.07394},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2604.07394}, 
}
Safetensors · Model size: 4B params · Tensor type: F32

Model tree for QQTang1223/full_xattn_Qwen3-4B: finetuned from Qwen/Qwen3-4B.