razlapid/sae-gemma-3-4b-it-classifier
Multi-head linear safety classifier trained on SAE features from google/gemma-3-4b-it.
Model Details
- Backbone: google/gemma-3-4b-it
- Layer: 17 (resid_post)
- d_model: 2560
- Classification heads: harmfulness, prompt_injection
- Architecture: LayerNorm → Dropout → N x Linear(d_sae, 1)
- Training samples: 35
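The architecture line above can be read as the following inference-time forward pass. This is a minimal pure-Python sketch, not the repository's `MultiHeadProbe` code: the function names, the toy weights, the omission of LayerNorm's learned scale and shift, and the per-head sigmoid are all assumptions made for illustration.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one feature vector to zero mean and unit variance.
    Assumption: the learned affine parameters are omitted for brevity."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def probe_forward(features, heads):
    """features: one pooled SAE vector of length d_sae.
    heads: dict mapping head name -> (weights, bias), i.e. one Linear(d_sae, 1) per head."""
    h = layer_norm(features)
    # Dropout is the identity at inference time, so it is skipped here.
    return {
        name: sigmoid(sum(w * v for w, v in zip(weights, h)) + bias)
        for name, (weights, bias) in heads.items()
    }

# Toy example with d_sae = 4 and hypothetical weights for the two heads.
heads = {
    "harmfulness": ([0.5, -0.25, 0.0, 1.0], 0.0),
    "prompt_injection": ([0.0, 0.0, 0.0, 0.0], 0.0),
}
probs = probe_forward([1.0, 2.0, 3.0, 4.0], heads)
```

Each head shares the normalized input and produces its own independent probability, so a prompt can score high on both heads at once.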
Results
| Pooling | Val Mean AUC | Test Mean AUC |
|---|---|---|
| mean | 0.4722 | 0.6250 |
| last_token | 0.7778 | 0.6875 |
| max | 0.7333 | 0.9444 |
| topk_max | 0.4778 | 0.6319 |
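The card does not say how "Mean AUC" is computed. One plausible reading, sketched below as an assumption, is the unweighted average of the per-head ROC AUC over the two classification heads. The helper names and toy labels are hypothetical and the numbers are not taken from the evaluation above.

```python
def roc_auc(labels, scores):
    """ROC AUC as the probability that a random positive outscores
    a random negative, counting ties as half a win."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def mean_auc(labels_by_head, scores_by_head):
    """Assumption: 'Mean AUC' is the unweighted mean of per-head AUCs."""
    aucs = [roc_auc(labels_by_head[h], scores_by_head[h]) for h in labels_by_head]
    return sum(aucs) / len(aucs)

# Toy data: four examples, two heads.
labels = {"harmfulness": [1, 0, 1, 0], "prompt_injection": [0, 0, 1, 1]}
scores = {"harmfulness": [0.9, 0.2, 0.4, 0.6], "prompt_injection": [0.1, 0.3, 0.8, 0.7]}
result = mean_auc(labels, scores)  # (0.75 + 1.0) / 2 = 0.875
```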
Usage
from model import MultiHeadProbe
probe = MultiHeadProbe.from_checkpoint("probe_mean/probe_best.pt", device="cuda")
# features: SAE-encoded, pooled activations of shape [batch, d_sae]
probs = probe.predict_proba(features)
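The probe expects features that have already been pooled over the token dimension. The sketch below shows one way the four pooling modes named in the Results table could reduce per-token SAE activations to a single vector. It is written in pure Python for clarity; the `pool` function, its `k` parameter, and in particular the definition of `topk_max` (taken here as the mean of the k largest activations per feature) are assumptions, not the pipeline's confirmed behavior.

```python
def pool(token_features, strategy, k=2):
    """Reduce per-token SAE activations [seq_len][d_sae] to one vector [d_sae]."""
    columns = list(zip(*token_features))  # per-feature values across all tokens
    if strategy == "mean":
        return [sum(col) / len(col) for col in columns]
    if strategy == "last_token":
        return list(token_features[-1])
    if strategy == "max":
        return [max(col) for col in columns]
    if strategy == "topk_max":
        # Assumption: average the k largest activations of each feature.
        return [
            sum(sorted(col, reverse=True)[:k]) / min(k, len(col))
            for col in columns
        ]
    raise ValueError(f"unknown pooling strategy: {strategy}")

# Toy sequence of three tokens with d_sae = 2.
tokens = [[0.0, 1.0], [2.0, 0.0], [4.0, 5.0]]
pooled = {name: pool(tokens, name) for name in ("mean", "last_token", "max", "topk_max")}
```

Whatever pooling is used at inference should match the one the loaded checkpoint was trained with.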
Pipeline
Trained with SAE Guard Pipeline.