一(yi)、模型訓練
1. 為保證后續(xu)的模型轉(zhuan)換可能會失敗,選取了ModelZoo-PyTorch中支持的代碼倉庫dbnet來進行訓練,該倉庫地址(zhi)為dbnet作者(zhe)原(yuan)始的代碼倉庫。
2.配(pei)置相關(guan)訓(xun)(xun)練參數,訓(xun)(xun)練一個(ge)dbnet可(ke)用模(mo)型,因使用場(chang)景在印(yin)刷體文字,不需訓(xun)(xun)練多長時間便(bian)可(ke)以得到一個(ge)較好效果的pt格式(shi)模(mo)型文件(jian)。
二、模型(xing)轉換
1. 轉換onnx文(wen)件
python3 convert_onnx.py
2. 轉換om文件
om文件(jian)與(yu)機(ji)器有關,轉(zhuan)換時在昇(sheng)騰910機(ji)器上進(jin)行,使用atc工具(ju)進(jin)行轉(zhuan)換,轉(zhuan)換腳本如(ru)下:
atc --model=db_ic15_resnet50_16.onnx --framework=5 --output=db_ic15_resnet50_16 --input_format=NCHW --input_shape="input:1,3,960,960" --log=error --soc_version=Ascend910B3
其中 soc_version 指定(ding)機器的型(xing)號
三、模型測試
昇騰提供了一個名為benchmark的(de)python API,可以用(yong)來進行(xing)離線模型(xing)(.om模型(xing))推理。
ais_bench推理工具的安裝包括aclruntime包和(he)ais_bench推(tui)理(li)程(cheng)序(xu)包的安裝。安裝方式有多種,可以(yi)使用(yong)whl包進行(xing)安裝。安裝命令如(ru)下:
pip install aclruntime-0.0.2-cp310-cp310-linux_aarch64.whl
pip install ais_bench-0.0.2-py3-none-any.whl
安裝成功后,便(bian)可以引入ais_bench進行推理(li)。
ais_bench的使(shi)用(yong)包括導入依(yi)賴包、加載模(mo)(mo)型、圖像(xiang)預處理、調用接口推(tui)理模(mo)(mo)型得到(dao)輸出(chu)、圖像(xiang)后處理、釋放模(mo)(mo)型占用的內(nei)存幾個步(bu)驟。
以下給出dbnet的完整(zheng)測試代碼:
import argparse
from tqdm import tqdm
import numpy as np
import glob
import cv2
from ais_bench.infer.interface import InferSession
def resize_image(img):
input_w = 960
input_h = 960
h, w, c = img.shape
r_w = input_w / w
r_h = input_h / h
if r_h > r_w:
tw = input_w
th = int(r_w * h)
tx1 = tx2 = 0
ty1 = 0
ty2 = input_h - th
else:
tw = int(r_h * w)
th = input_h
tx1 = 0
tx2 = input_w - tw - tx1
ty1 = ty2 = 0
# Resize the image with long side while maintaining ratio
resized_img = cv2.resize(img, (tw, th))
resized_img = cv2.copyMakeBorder(
resized_img, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (0, 0, 0)
)
return resized_img, tw, th
def transfer_pic(origin_image):
# 圖像預處理
dbnet_input_data, tw, th = resize_image(origin_image)
dbnet_input_data = dbnet_input_data.astype(np.float16)
dbnet_input_data -= np.array(
[122.67891434, 116.66876762, 104.00698793])
dbnet_input_data /= 255.0
dbnet_input_data = dbnet_input_data.transpose([2, 0, 1])
dbnet_input_data = dbnet_input_data[np.newaxis, :]
return dbnet_input_data, tw, th
def get_bin(pred):
pred = pred[0][0]
_bitmap = pred > 0.3
bitmap = _bitmap
height, width = bitmap.shape
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
bitmap = (bitmap * 255).astype(np.uint8)
bitmap = cv2.dilate(bitmap, kernel)
# outs = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
return bitmap
def main(data_path, npu_session):
files = glob.glob(data_path+'/*.png')+glob.glob(data_path+'/*.jpg')
for data in tqdm(files):
data = cv2.imread(data)
data, new_w,new_h = transfer_pic(data)
npu_result = npu_session.infer(data, "static")
bitmap = get_bin(npu_result[0])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='infer E2E')
parser.add_argument("--data_path", default="img_test", help='data path')
parser.add_argument("--device", default=0, type=int, help='npu device')
parser.add_argument("--om_path", default="db_ic15_resnet50_16.om", help='om path')
flags = parser.parse_args()
db_session = InferSession(flags.device, flags.om_path)
main(flags.data_path,db_session)