原项目 下载models v1.11 因为项目要用到official包里的文件,所以将下载的models压缩包解压后,把其中的official文件夹拷贝到python包路径下,比如:
1 ~/.virtualenvs/deepspeech/lib/python3.5/site-packages
安装official依赖:
1 $ pip install -r official/requirements.txt
训练中文 数据集必须是wav文件,采样率16000,单声道,时长不超过27.0秒。 标签文件格式如下:
1 2 3 wav_filename wav_filesize transcript /home/fanrong/ASR/dataset/data_thchs30/train/A2_160.wav 342044 用 传 统 的 常 规 育 秧 方 式 每 育 一 亩 秧 苗 只 可 栽 插 八 亩 大 田 底 膜 育 秧 则 可 移 栽 五 十 亩 ...
第一行是表头,后面每一行是文件路径、制表符、文件大小、制表符、语音对应的汉字。文件保存为.csv
。 要去掉时长超过27.0秒的音频,可以用如下shell文件:
1 2 3 4 train_file="train_dataset.csv" final_train_file="final_train_dataset.csv" MAX_AUDIO_LEN=27.0 awk -v maxlen="$MAX_AUDIO_LEN" 'BEGIN{FS="\t";} NR==1{print $0} NR>1{cmd="soxi -D "$1""; cmd|getline x; if(x<=maxlen) {print $0}; close(cmd);}' $train_file > $final_train_file
修改原项目/data/vocabulary.txt文件,#表示注释
1 2 3 4 5 6 7 8 9 10 11 12 13 # begin of vocabulary 这 片 种 植 区 的 . . . - # end of vacabulary
第一行是一个空格,最后一行是一个减号。 修改deep_speech.py中
1 greedy_decoder = decoder.DeepSpeechDecoder(speech_labels, blank_index=len(speech_labels))
加了一个blank_index,因为原项目用来识别英文,字典长度28,改成中文后长度发生变化。
开始训练 下面的命令是项目修改后要用的命令:
1 $ python deep_speech.py --train_data_dir=../data_thchs30/ch_train_dataset.csv --eval_data_dir=../data_thchs30/ch_test_dataset.csv --num_gpus=0
开始预测 1 $ python auto_speech_rec.py --pred_data_dir=../data_thchs30/ch_pred_dataset.csv --num_gpus=0
用tfserving部署 因为DeepSpeech2使用了Estimator所以和前一篇文章 中介绍的保存savedmodel的方法有所不同,首先当然要去官方文档 查看如何保存,官方文档是这样说的: 在training的时候需要一个input_fn()
来准备数据给模型使用,在serving的时候类似,需要一个serving_input_receiver_fn()
接收推理请求并做一些处理。该函数具有以下用途:
在graph中为推理请求添加placeholder。
添加将数据从输入格式转换为模型所预期的特征Tensor所需的任何额外操作。
该函数返回一个tf.estimator.export.ServingInputReceiver
对象,该对象会将placeholder和生成的feature tensor打包在一起。 典型的模式是推理请求以序列化tf.Example
的形式到达,因此serving_input_receiver_fn()
创建单个字符串占位符来接收它们。serving_input_receiver_fn() 也负责解析tf.Example,通过向图graph中添加tf.parse_example
操作,并将解析规范传递给tf.parse_example,告诉解析器可能会遇到哪些特征名称以及如何将它们映射到 Tensor。解析规范采用字典的形式,即从特征名称映射到tf.FixedLenFeature
、tf.VarLenFeature
和tf.SparseFeature
。综上所述:
1 2 3 4 5 6 7 8 9 10 11 feature_spec = {'foo': tf.FixedLenFeature(...), 'bar': tf.VarLenFeature(...)} def serving_input_receiver_fn(): """An input receiver that expects a serialized tf.Example.""" serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='input_example_tensor') receiver_tensors = {'examples': serialized_tf_example} features = tf.parse_example(serialized_tf_example, feature_spec) return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
看完官方文档大体知道是什么流程了,但是没有具体的例子,对于tf新手来说很难知道具体怎么操作,比如官方文档中的’序列化tf.Example形式’,从来没有用过。所以还是从github上找了一个例子 。例子中的export.py文件是用来保存savedmodel的,其中实现的serving_input_receiver_fn()并没有用到官方文档中的tf.Example和tf.parse_example等,也不用定义解析规范。只是定义了一个接收到的数据所需的placeholder,并将这个placeholder加工处理为feature tensor,最后把它们打包在一起。 仿照例子中的形式,我在data/dataset.py中添加了serving_input_receiver_fn()函数:
1 2 3 4 5 6 7 def serving_input_receiver_fn(): audio_feature = tf.placeholder(dtype=tf.float32, shape=[None, None, 161, 1], name='features') input_length = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='input_length') label_length = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='label_length') receiver_tensors = {'features':audio_feature, 'input_length':input_length, 'label_length':label_length} features = receiver_tensors return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
其中的placeholder可以参照input_fn(),但是要多一维,因为有instances
批量推理请求的情况。接着需要在auto_speech_rec.py中的run_deep_speech()中添加:
1 2 estimator = tf.estimator.Estimator(...) estimator.export_savedmodel(flags_obj.export_dir, dataset.serving_input_receiver_fn)
运行一遍auto_speech_rec.py即可导出savedmodel形式的模型。最后还需要编写一个client来进行推理请求:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 import os import sys import decoder import requests import numpy as np import data.dataset as dataset import data.featurizer as featurizer from deep_speech import generate_dataset URL = 'http://localhost:8503/v1/models/deepspeech:predict' VOCABULARY = os.path.abspath('data/vocabulary.txt') audio_file = '/home/fanrong1/Project/ASR/dataset/data_thchs30/pred/A12_100.wav' if __name__ == '__main__': audio_featurizer = featurizer.AudioFeaturizer() text_featurizer = featurizer.TextFeaturizer(VOCABULARY) features = dataset._preprocess_audio(audio_file, audio_featurizer, True) predict_request = '{"inputs":{"features":%s, "input_length":[[%s]], "label_length":[[0]]}}' % ([features.tolist()], len(features)) response = requests.post(URL, data=predict_request) #print(response.text) response.raise_for_status() y_pred = response.json()['outputs'] pred = y_pred['probabilities'][0] greedy_decoder = decoder.DeepSpeechDecoder(text_featurizer.speech_labels, blank_index=len(text_featurizer.speech_labels)) decoded_str = greedy_decoder.decode(pred) decoded_str = decoded_str.replace(' ', '').replace('-', '') print(decoded_str)