迁移学习是一种在新数据集上重新训练 DNN 模型的技术,它比从头开始训练网络花费的时间更少。通过迁移学习,可以微调预训练模型的权重以对自定义数据集进行分类。在这些示例中,我们将使用 ResNet-18 和 SSD-Mobilenet 网络,尽管您也可以尝试其他网络。

尽管由于经常使用大型数据集和相关的计算需求,训练通常在具有离散 GPU 的 PC、服务器或云实例上执行,但通过使用迁移学习,我们能够重新训练 Jetson 上的各种网络以 开始训练和部署我们自己的 DNN 模型。
PyTorch 是我们将使用的机器学习框架,除了用于收集和标记您自己的训练数据集的基于相机的工具外,还提供了示例数据集和训练脚本以供使用。
一、安装Pytorch
在之前的教程中,Pytorch应该已经安装了,如果没有的话,运行以下命令行安装:
$ cd jetson-inference/build
$ ./install-pytorch.sh

注:自动 PyTorch 安装工具需要 JetPack 4.2 或更高版本。如果你想做对象检测训练,你应该使用 JetPack 4.4 或更新版本并安装 PyTorch for Python 3.6。
二、建立训练所需文件夹
在训练目录下建立以下文件和文件夹

其中:lables.txt用来存储对象的标签,train文件夹存放训练用的原始数据(照片),val文件夹存放验证数据,test文件夹存放模型训练完的测试数据。样品数量train>val>test,比例大致为100:20:5。
三、使用Jetson Nano自带图片捕捉工具
cd到jetson-inference/tools文件夹,输入以下命令启动camera-capture工具:
camera-capture --width=640 --height=480 --camera=/dev/video0

Dataset Type选择“Classification”,Dataset Path为存放训练照片的顶级目录,Class Lables为labels.txt所在目录。Current Set为train, val或test。

接下来就是选择不同文件夹对目标进行拍照。
三、模型训练
cd到/jetson-inference/python/training/classification文件夹,运行以下命令:
python3 train.py --model-dir=myModel ~/jetson-inference/myTrain
将上面的“myModel”和“myTrain”替换了自己的模型存储目录和训练数据模具。
一切正常的话pyTorch就会开始运行迁移训练:

经过35个Epoch,训练结束,模型储存在指定的目录下:

因为模型是由pyTorch生成的,接下来要将模型转换为Jeston nano可以识别的模型。同样是在classification目录下运行以下命令:
python3 onnx_export.py --model-dir=myModel
将上面的”myModel”替换为自己的模型目录。
转换结束,生成了新的onnx格式的模型:

四、在IDE中运行训练完的模型
程序代码同之前的图像识别一样,只是要将模型换成训练好的模型:
net=jetson.inference.imageNet('googlenet',['--model=/home/jetson/jetson-inference/python/training/classification/myModel/resnet18.onnx','--input_blob=input_0','--output_blob=output_0','--labels=/home/jetson/jetson-inference/myTrain/labels.txt'])
同样需要将上面的”myModel”和“myTrain”替换为自己的模型和训练目录。
示例代码如下:
import jetson.inference
import jetson.utils
import cv2import numpy as np
import time
width=640
height=480
cam=jetson.utils.gstCamera(width,height,'/dev/video0')
net=jetson.inference.imageNet('googlenet',['--model=/home/jetson/jetson-inference/python/training/classification/myModel/resnet18.onnx','--input_blob=input_0','--output_blob=output_0','--labels=/home/jetson/jetson-inference/myTrain/labels.txt'])
font=cv2.FONT_HERSHEY_SIMPLEX
timeMark=time.time()
fpsFilter=0
while True:
frame,width,height=cam.CaptureRGBA(zeroCopy=1)
classID,confidnece=net.Classify(frame,width,height)
item=net.GetClassDesc(classID)
dt=time.time()-timeMark
fps=1/dt
fpsFilter=.95*fpsFilter+.05*fps
timeMark=time.time()
frame=jetson.utils.cudaToNumpy(frame,width,height,4)
frame=cv2.cvtColor(frame,cv2.COLOR_RGBA2BGR).astype(np.uint8)
cv2.putText(frame,str(round(fpsFilter,1))+' fps '+item,(0,30),font,1,(0,0,255),2)
cv2.imshow('recCam',frame)
cv2.moveWindow('recCam',0,0)
if cv2.waitKey(1)==ord('q'): breakcam.release()
cv2.destroyAllWindows()
运行结果:
