BruceFan's Blog

Stay hungry, stay foolish

0%

DeepSpeech2使用

原项目
下载models v1.11
因为项目要用到official包里的文件,所以将下载的models压缩包解压后,把其中的official文件夹拷贝到python包路径下,比如:

1
~/.virtualenvs/deepspeech/lib/python3.5/site-packages

安装official依赖:

1
$ pip install -r official/requirements.txt

训练中文

数据集必须是wav文件,采样率16000,单声道,时长不超过27.0秒。
标签文件格式如下:

1
2
3
wav_filename    wav_filesize    transcript
/home/fanrong/ASR/dataset/data_thchs30/train/A2_160.wav 342044 用 传 统 的 常 规 育 秧 方 式 每 育 一 亩 秧 苗 只 可 栽 插 八 亩 大 田 底 膜 育 秧 则 可 移 栽 五 十 亩
...

第一行是表头,后面每一行是文件路径、制表符、文件大小、制表符、语音对应的汉字。文件保存为.csv
要去掉时长超过27.0秒的音频,可以用如下shell文件:

1
2
3
4
train_file="train_dataset.csv"
final_train_file="final_train_dataset.csv"
MAX_AUDIO_LEN=27.0
awk -v maxlen="$MAX_AUDIO_LEN" 'BEGIN{FS="\t";} NR==1{print $0} NR>1{cmd="soxi -D "$1""; cmd|getline x; if(x<=maxlen) {print $0}; close(cmd);}' $train_file > $final_train_file

修改原项目/data/vocabulary.txt文件,#表示注释

1
2
3
4
5
6
7
8
9
10
11
12
13
# begin of vocabulary







.
.
.
-
# end of vacabulary

第一行是一个空格,最后一行是一个减号。
修改deep_speech.py中

1
greedy_decoder = decoder.DeepSpeechDecoder(speech_labels, blank_index=len(speech_labels))

加了一个blank_index,因为原项目用来识别英文,字典长度28,改成中文后长度发生变化。

开始训练

下面的命令是项目修改后要用的命令:

1
$ python deep_speech.py --train_data_dir=../data_thchs30/ch_train_dataset.csv --eval_data_dir=../data_thchs30/ch_test_dataset.csv --num_gpus=0

开始预测

1
$ python auto_speech_rec.py --pred_data_dir=../data_thchs30/ch_pred_dataset.csv --num_gpus=0

用tfserving部署

因为DeepSpeech2使用了Estimator所以和前一篇文章中介绍的保存savedmodel的方法有所不同,首先当然要去官方文档查看如何保存,官方文档是这样说的:
在training的时候需要一个input_fn()来准备数据给模型使用,在serving的时候类似,需要一个serving_input_receiver_fn()接收推理请求并做一些处理。该函数具有以下用途:

  • 在graph中为推理请求添加placeholder。
  • 添加将数据从输入格式转换为模型所预期的特征Tensor所需的任何额外操作。

该函数返回一个tf.estimator.export.ServingInputReceiver对象,该对象会将placeholder和生成的feature tensor打包在一起。
典型的模式是推理请求以序列化tf.Example的形式到达,因此serving_input_receiver_fn()创建单个字符串占位符来接收它们。serving_input_receiver_fn() 也负责解析tf.Example,通过向图graph中添加tf.parse_example操作,并将解析规范传递给tf.parse_example,告诉解析器可能会遇到哪些特征名称以及如何将它们映射到 Tensor。解析规范采用字典的形式,即从特征名称映射到tf.FixedLenFeaturetf.VarLenFeaturetf.SparseFeature。综上所述:

1
2
3
4
5
6
7
8
9
10
11
feature_spec = {'foo': tf.FixedLenFeature(...),
'bar': tf.VarLenFeature(...)}

def serving_input_receiver_fn():
"""An input receiver that expects a serialized tf.Example."""
serialized_tf_example = tf.placeholder(dtype=tf.string,
shape=[default_batch_size],
name='input_example_tensor')
receiver_tensors = {'examples': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

看完官方文档大体知道是什么流程了,但是没有具体的例子,对于tf新手来说很难知道具体怎么操作,比如官方文档中的’序列化tf.Example形式’,从来没有用过。所以还是从github上找了一个例子。例子中的export.py文件是用来保存savedmodel的,其中实现的serving_input_receiver_fn()并没有用到官方文档中的tf.Example和tf.parse_example等,也不用定义解析规范。只是定义了一个接收到的数据所需的placeholder,并将这个placeholder加工处理为feature tensor,最后把它们打包在一起。
仿照例子中的形式,我在data/dataset.py中添加了serving_input_receiver_fn()函数:

1
2
3
4
5
6
7
def serving_input_receiver_fn():
audio_feature = tf.placeholder(dtype=tf.float32, shape=[None, None, 161, 1], name='features')
input_length = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='input_length')
label_length = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='label_length')
receiver_tensors = {'features':audio_feature, 'input_length':input_length, 'label_length':label_length}
features = receiver_tensors
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

其中的placeholder可以参照input_fn(),但是要多一维,因为有instances批量推理请求的情况。接着需要在auto_speech_rec.py中的run_deep_speech()中添加:

1
2
estimator = tf.estimator.Estimator(...)
estimator.export_savedmodel(flags_obj.export_dir, dataset.serving_input_receiver_fn)

运行一遍auto_speech_rec.py即可导出savedmodel形式的模型。最后还需要编写一个client来进行推理请求:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import sys
import decoder
import requests
import numpy as np
import data.dataset as dataset
import data.featurizer as featurizer
from deep_speech import generate_dataset

URL = 'http://localhost:8503/v1/models/deepspeech:predict'
VOCABULARY = os.path.abspath('data/vocabulary.txt')
audio_file = '/home/fanrong1/Project/ASR/dataset/data_thchs30/pred/A12_100.wav'

if __name__ == '__main__':
audio_featurizer = featurizer.AudioFeaturizer()
text_featurizer = featurizer.TextFeaturizer(VOCABULARY)
features = dataset._preprocess_audio(audio_file, audio_featurizer, True)
predict_request = '{"inputs":{"features":%s, "input_length":[[%s]], "label_length":[[0]]}}' % ([features.tolist()], len(features))
response = requests.post(URL, data=predict_request)
#print(response.text)
response.raise_for_status()
y_pred = response.json()['outputs']
pred = y_pred['probabilities'][0]

greedy_decoder = decoder.DeepSpeechDecoder(text_featurizer.speech_labels, blank_index=len(text_featurizer.speech_labels))
decoded_str = greedy_decoder.decode(pred)
decoded_str = decoded_str.replace(' ', '').replace('-', '')
print(decoded_str)