Skip to content

Commit c3948ad

Browse files
Spycsh and pre-commit-ci[bot] authored Nov 25, 2024
openai compatible for asr/tts (#929)
* openai compatible for asr/tts * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add dep * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent bbca7fd commit c3948ad

File tree

10 files changed

+160
-26
lines changed

10 files changed

+160
-26
lines changed
 

‎comps/asr/whisper/README.md

+9
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ pip install optimum[habana]
3838
cd dependency/
3939
nohup python whisper_server.py --device=hpu &
4040
python check_whisper_server.py
41+
42+
# Or use openai protocol compatible curl command
43+
# Please refer to https://platform.openai.com/docs/api-reference/audio/createTranscription
44+
wget https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav
45+
curl http://localhost:7066/v1/audio/transcriptions \
46+
-H "Content-Type: multipart/form-data" \
47+
-F file="@./sample.wav" \
48+
-F model="openai/whisper-small"
4149
```
4250

4351
### 1.3 Start ASR Service/Test
@@ -114,6 +122,7 @@ docker run -d -p 9099:9099 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$
114122
# curl
115123
http_proxy="" curl http://localhost:9099/v1/audio/transcriptions -XPOST -d '{"byte_str": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA"}' -H 'Content-Type: application/json'
116124

125+
117126
# python
118127
python check_asr_server.py
119128
```

‎comps/asr/whisper/dependency/whisper_model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def __init__(
3030
from transformers import WhisperForConditionalGeneration, WhisperProcessor
3131

3232
self.device = device
33-
asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
34-
print("Downloading model: {}".format(asr_model_name_or_path))
35-
self.model = WhisperForConditionalGeneration.from_pretrained(asr_model_name_or_path).to(self.device)
36-
self.processor = WhisperProcessor.from_pretrained(asr_model_name_or_path)
33+
self.asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
34+
print("Downloading model: {}".format(self.asr_model_name_or_path))
35+
self.model = WhisperForConditionalGeneration.from_pretrained(self.asr_model_name_or_path).to(self.device)
36+
self.processor = WhisperProcessor.from_pretrained(self.asr_model_name_or_path)
3737
self.model.eval()
3838

3939
self.language = language

‎comps/asr/whisper/dependency/whisper_server.py

+55-3
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@
55
import base64
66
import os
77
import uuid
8+
from typing import List, Optional, Union
89

910
import uvicorn
10-
from fastapi import FastAPI, Request
11+
from fastapi import FastAPI, File, Form, Request, UploadFile
1112
from fastapi.responses import Response
1213
from pydub import AudioSegment
1314
from starlette.middleware.cors import CORSMiddleware
1415
from whisper_model import WhisperModel
1516

17+
from comps import CustomLogger
18+
from comps.cores.proto.api_protocol import AudioTranscriptionResponse
19+
20+
logger = CustomLogger("whisper")
21+
logflag = os.getenv("LOGFLAG", False)
22+
1623
app = FastAPI()
1724
asr = None
1825

@@ -29,7 +36,7 @@ async def health() -> Response:
2936

3037
@app.post("/v1/asr")
3138
async def audio_to_text(request: Request):
32-
print("Whisper generation begin.")
39+
logger.info("Whisper generation begin.")
3340
uid = str(uuid.uuid4())
3441
file_name = uid + ".wav"
3542
request_dict = await request.json()
@@ -44,13 +51,58 @@ async def audio_to_text(request: Request):
4451
try:
4552
asr_result = asr.audio2text(file_name)
4653
except Exception as e:
47-
print(e)
54+
logger.error(e)
4855
asr_result = e
4956
finally:
5057
os.remove(file_name)
5158
return {"asr_result": asr_result}
5259

5360

61+
@app.post("/v1/audio/transcriptions")
async def audio_transcriptions(
    file: UploadFile = File(...),  # Handling the uploaded file directly
    model: str = Form("openai/whisper-small"),
    language: str = Form("english"),
    prompt: str = Form(None),
    response_format: str = Form("json"),
    temperature: float = Form(0),
    timestamp_granularities: List[str] = Form(None),
):
    """OpenAI-compatible transcription endpoint.

    Mirrors https://platform.openai.com/docs/api-reference/audio/createTranscription:
    saves the uploaded audio to a unique temp file, resamples it to 16 kHz
    (Whisper's expected input rate), runs the ASR model, and returns the text.

    Raises:
        Exception: when the requested ``model`` does not match the served model.
    """
    logger.info("Whisper generation begin.")
    audio_content = await file.read()

    # Validate the request parameters: the served model must match the request.
    if model != asr.asr_model_name_or_path:
        raise Exception(
            f"ASR model mismatch! Please make sure you pass --model_name_or_path or set environment variable ASR_MODEL_PATH to {model}"
        )
    asr.language = language
    # BUG FIX: the original warning named 'language', but 'language' IS honored
    # (assigned above). The condition actually checks 'prompt', which is the
    # parameter being ignored — the message now matches the condition.
    if prompt is not None or response_format != "json" or temperature != 0 or timestamp_granularities is not None:
        logger.warning(
            "Currently parameters 'prompt', 'response_format', 'temperature', 'timestamp_granularities' are not supported!"
        )

    # Unique per-request file name so concurrent requests don't clobber each other.
    uid = str(uuid.uuid4())
    file_name = uid + ".wav"
    # Save the uploaded file
    with open(file_name, "wb") as buffer:
        buffer.write(audio_content)

    # Resample to 16 kHz in place — Whisper expects 16 kHz mono-style input.
    audio = AudioSegment.from_file(file_name)
    audio = audio.set_frame_rate(16000)
    audio.export(file_name, format="wav")

    try:
        asr_result = asr.audio2text(file_name)
    except Exception as e:
        logger.error(e)
        # BUG FIX: coerce to str — AudioTranscriptionResponse.text is typed
        # ``str`` and pydantic would reject a raw Exception object.
        asr_result = str(e)
    finally:
        os.remove(file_name)

    return AudioTranscriptionResponse(text=asr_result)
104+
105+
54106
if __name__ == "__main__":
55107
parser = argparse.ArgumentParser()
56108
parser.add_argument("--host", type=str, default="0.0.0.0")

‎comps/asr/whisper/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ opentelemetry-sdk
77
prometheus-fastapi-instrumentator
88
pydantic==2.7.2
99
pydub
10+
python-multipart
1011
shortuuid
1112
transformers
1213
uvicorn

‎comps/cores/proto/api_protocol.py

+32
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,38 @@ class AudioChatCompletionRequest(BaseModel):
300300
user: Optional[str] = None
301301

302302

303+
# Pydantic does not support UploadFile directly
304+
# class AudioTranscriptionRequest(BaseModel):
305+
# # Ordered by official OpenAI API documentation
306+
# # default values are same with
307+
# # https://platform.openai.com/docs/api-reference/audio/createTranscription
308+
# file: UploadFile = File(...)
309+
# model: Optional[str] = "openai/whisper-small"
310+
# language: Optional[str] = "english"
311+
# prompt: Optional[str] = None
312+
# response_format: Optional[str] = "json"
313+
# temperature: Optional[str] = 0
314+
# timestamp_granularities: Optional[List] = None
315+
316+
317+
class AudioTranscriptionResponse(BaseModel):
    """Response body for ``/v1/audio/transcriptions``.

    Field names and defaults follow the official OpenAI API:
    https://platform.openai.com/docs/api-reference/audio/json-object
    """

    # The transcribed text produced by the ASR backend.
    text: str
322+
323+
324+
class AudioSpeechRequest(BaseModel):
    """Request body for ``/v1/audio/speech``.

    Fields are ordered per the official OpenAI API documentation and the
    defaults match https://platform.openai.com/docs/api-reference/audio/createSpeech
    """

    # Text to synthesize into speech.
    input: str
    # TTS model identifier.
    model: Optional[str] = "microsoft/speecht5_tts"
    # Voice selector.
    voice: Optional[str] = "default"
    # Desired audio container/format for the response.
    response_format: Optional[str] = "mp3"
    # Playback speed multiplier.
    speed: Optional[float] = 1.0
333+
334+
303335
class ChatMessage(BaseModel):
304336
role: str
305337
content: str

‎comps/tts/gpt-sovits/Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libiomp5.so:/usr/lib/x86_64-linux-gnu/l
1616
ENV MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000"
1717

1818
# Clone source repo
19-
RUN git clone https://github.com/RVC-Boss/GPT-SoVITS.git
19+
RUN git clone --branch openai_compat --single-branch https://github.com/Spycsh/GPT-SoVITS.git
2020
# Download pre-trained models, and prepare env
2121
RUN git clone https://huggingface.co/lj1995/GPT-SoVITS pretrained_models
2222
RUN mv pretrained_models/* GPT-SoVITS/GPT_SoVITS/pretrained_models/ && \

‎comps/tts/gpt-sovits/README.md

+7-1
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,15 @@ wget https://github.com/OpenTalker/SadTalker/blob/main/examples/driven_audio/chi
5454

5555
docker cp chinese_poem1.wav gpt-sovits-service:/home/user/chinese_poem1.wav
5656

57-
http_proxy="" curl localhost:9880/change_refer -d '{
57+
curl localhost:9880/change_refer -d '{
5858
"refer_wav_path": "/home/user/chinese_poem1.wav",
5959
"prompt_text": "窗前明月光,疑是地上霜,举头望明月,低头思故乡。",
6060
"prompt_language": "zh"
6161
}'
6262
```
63+
64+
- openai protocol compatible request
65+
66+
```bash
67+
curl localhost:9880/v1/audio/speech -XPOST -d '{"input":"你好呀,你是谁. Hello, who are you?"}' -H 'Content-Type: application/json' --output speech.mp3
68+
```

‎comps/tts/speecht5/README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ docker run -p 9088:9088 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$htt
8585
#### 2.2.3 Test
8686

8787
```bash
88-
# curl
8988
curl http://localhost:7055/v1/tts -XPOST -d '{"text": "Who are you?"}' -H 'Content-Type: application/json'
9089

90+
# openai protocol compatible
91+
# voice can be 'male' or 'default'
92+
curl http://localhost:7055/v1/audio/speech -XPOST -d '{"input":"Who are you?", "voice": "male"}' -H 'Content-Type: application/json' --output speech.wav
93+
9194
curl http://localhost:9088/v1/audio/speech -XPOST -d '{"text": "Who are you?"}' -H 'Content-Type: application/json'
9295
```

‎comps/tts/speecht5/dependency/speecht5_model.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,34 @@ def __init__(self, device="cpu"):
1717

1818
adapt_transformers_to_gaudi()
1919

20-
model_name_or_path = "microsoft/speecht5_tts"
20+
self.model_name_or_path = "microsoft/speecht5_tts"
2121
vocoder_model_name_or_path = "microsoft/speecht5_hifigan"
22-
self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name_or_path).to(device)
22+
self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name_or_path).to(device)
2323
self.model.eval()
24-
self.processor = SpeechT5Processor.from_pretrained(model_name_or_path, normalize=True)
24+
self.processor = SpeechT5Processor.from_pretrained(self.model_name_or_path, normalize=True)
2525
self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_model_name_or_path).to(device)
2626
self.vocoder.eval()
2727

2828
# fetch default speaker embedding
29-
if os.path.exists("spk_embed_default.pt"):
30-
self.default_speaker_embedding = torch.load("spk_embed_default.pt")
31-
else:
32-
try:
29+
try:
30+
for spk_embed in ["spk_embed_default.pt", "spk_embed_male.pt"]:
31+
if os.path.exists(spk_embed):
32+
continue
33+
3334
p = subprocess.Popen(
3435
[
3536
"curl",
3637
"-O",
3738
"https://raw.githubusercontent.com/intel/intel-extension-for-transformers/main/"
38-
"intel_extension_for_transformers/neural_chat/assets/speaker_embeddings/"
39-
"spk_embed_default.pt",
39+
"intel_extension_for_transformers/neural_chat/assets/speaker_embeddings/" + spk_embed,
4040
]
4141
)
42+
4243
p.wait()
43-
self.default_speaker_embedding = torch.load("spk_embed_default.pt")
44-
except Exception as e:
45-
print("Warning! Need to prepare speaker_embeddings, will use the backup embedding.")
46-
self.default_speaker_embedding = torch.zeros((1, 512))
44+
self.default_speaker_embedding = torch.load("spk_embed_default.pt")
45+
except Exception as e:
46+
print("Warning! Need to prepare speaker_embeddings, will use the backup embedding.")
47+
self.default_speaker_embedding = torch.zeros((1, 512))
4748

4849
if self.device == "hpu":
4950
# do hpu graph warmup with variable inputs
@@ -87,7 +88,9 @@ def _warmup_speecht5_hpu_graph(self):
8788
"OPEA is an ecosystem orchestration framework to integrate performant GenAI technologies & workflows leading to quicker GenAI adoption and business value."
8889
)
8990

90-
def t2s(self, text):
91+
def t2s(self, text, voice="default"):
92+
if voice == "male":
93+
self.default_speaker_embedding = torch.load("spk_embed_male.pt")
9194
if self.device == "hpu":
9295
# See https://github.com/huggingface/optimum-habana/pull/824
9396
from optimum.habana.utils import set_seed

‎comps/tts/speecht5/dependency/speecht5_server.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33

44
import argparse
55
import base64
6+
import os
67

78
import soundfile as sf
89
import uvicorn
910
from fastapi import FastAPI, Request
10-
from fastapi.responses import Response
11+
from fastapi.responses import Response, StreamingResponse
1112
from speecht5_model import SpeechT5Model
1213
from starlette.middleware.cors import CORSMiddleware
1314

15+
from comps import CustomLogger
16+
from comps.cores.proto.api_protocol import AudioSpeechRequest
17+
18+
logger = CustomLogger("speecht5")
19+
logflag = os.getenv("LOGFLAG", False)
20+
1421
app = FastAPI()
1522
tts = None
1623

@@ -27,7 +34,7 @@ async def health() -> Response:
2734

2835
@app.post("/v1/tts")
2936
async def text_to_speech(request: Request):
30-
print("SpeechT5 generation begin.")
37+
logger.info("SpeechT5 generation begin.")
3138
request_dict = await request.json()
3239
text = request_dict.pop("text")
3340

@@ -40,6 +47,27 @@ async def text_to_speech(request: Request):
4047
return {"tts_result": b64_str}
4148

4249

50+
@app.post("/v1/audio/speech")
async def audio_speech(request: AudioSpeechRequest):
    """OpenAI-compatible speech-synthesis endpoint.

    Mirrors https://platform.openai.com/docs/api-reference/audio/createSpeech:
    validates the requested model/voice, synthesizes audio with SpeechT5, and
    streams the resulting WAV bytes back to the client.

    Raises:
        Exception: when the requested model is not the served SpeechT5 model.
    """
    import uuid  # local import: keeps this fix self-contained in the endpoint

    logger.info("SpeechT5 generation begin.")
    # Validate the request parameters.
    if request.model != tts.model_name_or_path:
        raise Exception("TTS model mismatch! Currently only support model: microsoft/speecht5_tts")
    if request.voice not in ["default", "male"] or request.speed != 1.0:
        logger.warning("Currently parameter 'speed' can only be 1.0 and 'voice' can only be default or male!")

    speech = tts.t2s(request.input, voice=request.voice)

    # BUG FIX: the original wrote to a shared hard-coded "tmp.wav", so
    # concurrent requests overwrote each other's audio and the file was
    # never removed. Use a unique per-request name and delete it once
    # streaming finishes.
    tmp_path = f"{uuid.uuid4()}.wav"
    sf.write(tmp_path, speech, samplerate=16000)

    def audio_gen():
        try:
            with open(tmp_path, "rb") as f:
                yield from f
        finally:
            os.remove(tmp_path)

    return StreamingResponse(audio_gen(), media_type=f"audio/{request.response_format}")
69+
70+
4371
if __name__ == "__main__":
4472
parser = argparse.ArgumentParser()
4573
parser.add_argument("--host", type=str, default="0.0.0.0")

0 commit comments

Comments
 (0)
Please sign in to comment.