YWMditto committed on
Commit
dcb0a1c
·
1 Parent(s): 4a766ad

update readme

Browse files
Files changed (1) hide show
  1. README.md +24 -3
README.md CHANGED
@@ -132,7 +132,7 @@ Notes:
132
 
133
 
134
  ```python
135
- import os
136
  from pathlib import Path
137
  import torch
138
  import torchaudio
@@ -148,6 +148,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-VoiceGenerator"
148
  device = "cuda" if torch.cuda.is_available() else "cpu"
149
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  processor = AutoProcessor.from_pretrained(
152
  pretrained_model_name_or_path,
153
  trust_remote_code=True,
@@ -179,14 +201,13 @@ conversations = [
179
  model = AutoModel.from_pretrained(
180
  pretrained_model_name_or_path,
181
  trust_remote_code=True,
182
- attn_implementation="sdpa",
183
  torch_dtype=dtype,
184
  ).to(device)
185
  model.eval()
186
 
187
  batch_size = 1
188
 
189
- messages = []
190
  save_dir = Path("inference_root")
191
  save_dir.mkdir(exist_ok=True, parents=True)
192
  sample_idx = 0
 
132
 
133
 
134
  ```python
135
+ import importlib.util
136
  from pathlib import Path
137
  import torch
138
  import torchaudio
 
148
  device = "cuda" if torch.cuda.is_available() else "cpu"
149
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
150
 
151
+ def resolve_attn_implementation() -> str:
152
+ # Prefer FlashAttention 2 when package + device conditions are met.
153
+ if (
154
+ device == "cuda"
155
+ and importlib.util.find_spec("flash_attn") is not None
156
+ and dtype in {torch.float16, torch.bfloat16}
157
+ ):
158
+ major, _ = torch.cuda.get_device_capability()
159
+ if major >= 8:
160
+ return "flash_attention_2"
161
+
162
+ # CUDA fallback: use PyTorch SDPA kernels.
163
+ if device == "cuda":
164
+ return "sdpa"
165
+
166
+ # CPU fallback.
167
+ return "eager"
168
+
169
+
170
+ attn_implementation = resolve_attn_implementation()
171
+ print(f"[INFO] Using attn_implementation={attn_implementation}")
172
+
173
  processor = AutoProcessor.from_pretrained(
174
  pretrained_model_name_or_path,
175
  trust_remote_code=True,
 
201
  model = AutoModel.from_pretrained(
202
  pretrained_model_name_or_path,
203
  trust_remote_code=True,
204
+ attn_implementation=attn_implementation,
205
  torch_dtype=dtype,
206
  ).to(device)
207
  model.eval()
208
 
209
  batch_size = 1
210
 
 
211
  save_dir = Path("inference_root")
212
  save_dir.mkdir(exist_ok=True, parents=True)
213
  sample_idx = 0