razlapid/sae-gemma-3-4b-it-classifier

Multi-head linear safety classifier trained on sparse autoencoder (SAE) features extracted from google/gemma-3-4b-it.

Model Details

  • Backbone: google/gemma-3-4b-it
  • Layer: 17 (resid_post)
  • d_model: 2560
  • Classification heads: harmfulness, prompt_injection
  • Architecture: LayerNorm → Dropout → N × Linear(d_sae, 1)
  • Training samples: 35
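The architecture line can be illustrated with a small inference-time forward pass. This is a minimal NumPy sketch under stated assumptions, not the repository's `MultiHeadProbe` code: it assumes each head's logit is passed through an independent sigmoid (one binary probability per head), treats Dropout as the identity (as it is at inference), and uses a hypothetical `d_sae = 16` because the card does not state the SAE width. `layer_norm`, `probe_forward` and the `params` layout are illustrative names only.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each row over the feature axis, then apply the learned affine transform.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def probe_forward(features, params):
    """Hypothetical forward pass: LayerNorm -> Dropout -> N x Linear(d_sae, 1).

    features: [batch, d_sae] pooled SAE features.
    Returns a dict mapping head name -> [batch] probabilities.
    """
    h = layer_norm(features, params["ln_gamma"], params["ln_beta"])
    # Dropout is the identity at inference time, so it is omitted here.
    probs = {}
    for name, (w, b) in params["heads"].items():
        # Each head is an independent Linear(d_sae, 1): one logit per example.
        probs[name] = sigmoid(h @ w + b)
    return probs


# Toy usage with random weights; d_sae = 16 is hypothetical.
rng = np.random.default_rng(0)
d_sae = 16
params = {
    "ln_gamma": np.ones(d_sae),
    "ln_beta": np.zeros(d_sae),
    "heads": {
        name: (rng.normal(size=d_sae), 0.0)
        for name in ("harmfulness", "prompt_injection")
    },
}
features = rng.random((4, d_sae))
probs = probe_forward(features, params)
```

Keeping the heads as separate `Linear(d_sae, 1)` layers over a shared normalized input lets each label be scored independently, so a prompt can be flagged as both harmful and an injection attempt.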

Results

Pooling       Val Mean AUC   Test Mean AUC
mean          0.4722         0.6250
last_token    0.7778         0.6875
max           0.7333         0.9444
topk_max      0.4778         0.6319
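The card does not define "Mean AUC". A plausible reading, assumed here and not confirmed by the source, is the unweighted mean of the per-head ROC AUCs for `harmfulness` and `prompt_injection`. The sketch below computes ROC AUC with the Mann-Whitney formulation (the fraction of positive/negative pairs ranked correctly, ties counting one half); `roc_auc`, `mean_auc` and the toy labels and scores are illustrative only.

```python
def roc_auc(labels, scores):
    """ROC AUC via the Mann-Whitney U statistic; ties count as half a win."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined without both positive and negative examples")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def mean_auc(labels_by_head, scores_by_head):
    """Unweighted mean of per-head AUCs (assumed meaning of 'Mean AUC')."""
    aucs = {h: roc_auc(labels_by_head[h], scores_by_head[h]) for h in labels_by_head}
    return sum(aucs.values()) / len(aucs), aucs


# Toy data, not this model's results.
labels = {"harmfulness": [1, 0, 1, 0], "prompt_injection": [0, 0, 1, 1]}
scores = {"harmfulness": [0.9, 0.2, 0.6, 0.7], "prompt_injection": [0.1, 0.4, 0.8, 0.5]}
avg, per_head = mean_auc(labels, scores)
# per_head == {"harmfulness": 0.75, "prompt_injection": 1.0}, avg == 0.875
```

Because AUC counts correctly ordered pairs, on small validation and test splits a single misranked pair shifts the score by a large step, which is worth keeping in mind when comparing the pooling rows.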

Usage

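from model import MultiHeadProbe

# Load the checkpoint trained with mean pooling
probe = MultiHeadProbe.from_checkpoint("probe_mean/probe_best.pt", device="cuda")
# features: SAE-encoded features pooled over tokens, shape [batch, d_sae],
# pooled with the same method the loaded checkpoint was trained on
probs = probe.predict_proba(features)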
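The probe expects features already pooled over the token axis. The four pooling names in the Results table can be sketched as reductions from per-token SAE features `[seq_len, d_sae]` to a single `[d_sae]` vector. This is an illustrative NumPy sketch, not the SAE Guard Pipeline's code: `pool` is a hypothetical helper, it assumes an unpadded single sequence (so `last_token` is simply the final row), and `topk_max` is a hypothetical reading, taken here as the mean of each feature's k largest activations.

```python
import numpy as np


def pool(token_features, method="mean", k=4):
    """Collapse per-token SAE features [seq_len, d_sae] into one vector [d_sae].

    'mean', 'last_token' and 'max' follow their usual meanings. 'topk_max' is a
    hypothetical reading: the mean of each feature's k largest activations.
    """
    if method == "mean":
        return token_features.mean(axis=0)
    if method == "last_token":
        # Assumes no padding, so the last row is the final real token.
        return token_features[-1]
    if method == "max":
        return token_features.max(axis=0)
    if method == "topk_max":
        k = min(k, token_features.shape[0])
        # Sort each feature column ascending and average its k largest values.
        top = np.sort(token_features, axis=0)[-k:]
        return top.mean(axis=0)
    raise ValueError(f"unknown pooling method: {method}")


# Toy example: 3 tokens, 2 SAE features.
acts = np.array([[0.0, 1.0],
                 [2.0, 0.0],
                 [4.0, 3.0]])
pooled = {m: pool(acts, m, k=2) for m in ("mean", "last_token", "max", "topk_max")}
# A probe batch stacks one pooled vector per prompt: [batch, d_sae].
batch = np.stack([pooled["mean"], pooled["max"]])
```

Whichever method is used at inference should match the one the loaded checkpoint was trained with, since the LayerNorm and head weights were fit to that feature distribution.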

Pipeline

Trained with the SAE Guard Pipeline.
