# GENATATOR-GENA-RMT-Large-UNET (Multispecies Gene Segmentation Model)

## Overview
GENATATOR-GENA-RMT-Large-UNET is a DNA language model designed for gene segmentation directly from genomic DNA sequences.
The model performs nucleotide-level multilabel classification and predicts five gene structure classes:
| Class | Description |
|---|---|
| 5UTR | 5′ untranslated region |
| exon | exon |
| intron | intron |
| 3UTR | 3′ untranslated region |
| CDS | coding sequence |
The order of output classes in the model is:

```python
["5UTR", "exon", "intron", "3UTR", "CDS"]
```
The model outputs one logit vector per nucleotide, allowing reconstruction of complete gene structures.
## Model

Model name on Hugging Face: `genatator-gena-rmt-large-unet-multispecies-segmentation`
Architecture properties:
- backbone: GENA-LM BERT Large
- layers: 24
- hidden size: 1024
- attention heads: 16
- tokenization: GENA tokenization
- segmentation head: RMT encoder + 1D UNET refinement
- output head: linear projection to 5 classes
- context handling: recurrent memory segmentation
- recommended maximum sequence length for practical inference: up to 1,000,000 nucleotides
The architecture combines three components:

1. **GENA-LM backbone (BERT Large).** A pretrained genomic language model that produces contextual sequence representations.
2. **Recurrent Memory Transformer (RMT).** Long sequences are processed in segments with memory tokens, allowing the model to operate beyond the base backbone context length.
3. **1D UNET segmentation head.** Token-level representations are repeated to nucleotide resolution and refined with repeated 1D convolutional UNET passes before the final projection to the five output classes.
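The token-to-nucleotide upsampling used by the segmentation head can be illustrated with a minimal sketch. The per-token nucleotide counts below are illustrative only, not real GENA tokenizer output:

```python
import torch

# Minimal sketch of the token-to-nucleotide upsampling step, assuming each
# token's hidden vector is repeated once per nucleotide it covers.
# Token lengths here are made up for illustration.
hidden = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)  # (batch=1, 3 tokens, dim=2)
token_lengths = torch.tensor([2, 1, 3])  # nucleotides covered by each token (assumed)

# repeat each token's representation along the sequence axis
per_nucleotide = torch.repeat_interleave(hidden, token_lengths, dim=1)
print(per_nucleotide.shape)  # torch.Size([1, 6, 2])
```

The resulting per-nucleotide tensor is what a convolutional refinement head can then operate on at single-base resolution.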
## Batch Size Limitation
The current implementation supports batch size = 1 only.
This limitation comes from the 1D convolutional UNET refinement head used after token repetition to nucleotide resolution. Inference is therefore currently performed one sequence at a time.
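Given the batch-size-1 constraint, multiple sequences are best handled with a plain loop rather than stacking them into a batch. `run_model` below is a hypothetical stand-in for the tokenize-and-forward step shown in the usage section:

```python
# Sketch of processing several sequences under the batch size = 1
# constraint: one forward pass per sequence, never a stacked batch.
# `run_model` is a hypothetical callable, not part of the model's API.
def predict_sequences(sequences, run_model):
    results = []
    for seq in sequences:
        results.append(run_model(seq))  # exactly one sequence per call
    return results

# toy stand-in that just returns the sequence length
print(predict_sequences(["ACGT", "ACGTAC"], run_model=len))  # [4, 6]
```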
## Context Length

Because the model processes long inputs with RMT-based segmentation and memory tokens, it is not limited to the backbone's fixed context window in the usual sense.
In practice, we currently recommend using sequences of no more than 1,000,000 nucleotides per inference run.
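A simple guard for the recommended limit might look like the following sketch; the helper name and error message are ours, not part of the model's API:

```python
MAX_LEN = 1_000_000  # recommended practical limit from this model card

def check_length(seq: str) -> None:
    """Raise if a sequence exceeds the recommended inference length."""
    if len(seq) > MAX_LEN:
        raise ValueError(
            f"sequence of {len(seq)} nt exceeds the recommended "
            f"{MAX_LEN} nt limit; consider splitting it"
        )

check_length("ACGT" * 100)  # fine, no exception
```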
## Training Data
This model was fine-tuned on gene sequences only, not on full genomes.
Training data includes:
- mRNA transcripts
- lncRNA transcripts
Dataset characteristics:
- one transcript per gene
- no intergenic regions
- multispecies training dataset
Each training sample corresponds to a single gene sequence.
## Usage

```python
from transformers import AutoTokenizer, AutoModelForTokenClassification

repo_id = "shmelev/genatator-gena-rmt-large-unet-multispecies-segmentation"

tokenizer = AutoTokenizer.from_pretrained(
    repo_id,
    trust_remote_code=True,
)
model = AutoModelForTokenClassification.from_pretrained(
    repo_id,
    trust_remote_code=True,
)
model.eval()
```
## Example Inference
Note: only one sequence should be passed at a time.
```python
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

repo_id = "shmelev/genatator-gena-rmt-large-unet-multispecies-segmentation"

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

sequence = "ACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT"

enc = tokenizer([sequence])
input_ids = torch.tensor(enc["input_ids"])

with torch.no_grad():
    outputs = model(input_ids=input_ids)
    logits = outputs["logits"]

print("Input shape:", input_ids.shape)
print("Logits shape:", logits.shape)
```
Example output:

```
Input shape: torch.Size([1, sequence_length])
Logits shape: torch.Size([1, sequence_length, 5])
```
Each nucleotide receives 5 logits corresponding to the gene structure classes.
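Since the task is multilabel, one natural way to turn the logits into per-nucleotide label sets is an independent sigmoid per class with a 0.5 threshold. This decoding is a sketch; the model card does not prescribe a specific threshold:

```python
import torch

# Output class order from the model card
CLASSES = ["5UTR", "exon", "intron", "3UTR", "CDS"]

def decode_logits(logits: torch.Tensor, threshold: float = 0.5):
    """Convert (1, seq_len, 5) logits into per-nucleotide label sets.

    The task is multilabel (a nucleotide can be e.g. both `exon` and
    `CDS`), so each class gets an independent sigmoid rather than a
    softmax over classes.
    """
    probs = torch.sigmoid(logits.squeeze(0))  # (seq_len, 5)
    mask = probs > threshold
    return [
        [CLASSES[j] for j in range(len(CLASSES)) if mask[i, j]]
        for i in range(mask.shape[0])
    ]

# synthetic logits for two nucleotides: strong exon+CDS, then intron
fake = torch.tensor([[[-5.0, 4.0, -5.0, -5.0, 4.0],
                      [-5.0, -5.0, 4.0, -5.0, -5.0]]])
print(decode_logits(fake))  # [['exon', 'CDS'], ['intron']]
```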