先介绍一下学者使用的运动轨迹预测数据集
Argoverse Motion Forecasting Dataset v1.1
现在Argoverse数据集已经出到v2版本,可以支持windows系统,但大多学者都是用2019发布的Argoverse v1.1,这个版本的api没有提供windows系统的支持,数据集作者说应该是转义字符的问题。(Argoverse v2的Motion Forecasting Dataset更大,全部下载完要50+g)。
https://github.com/argoai/argoverse-api
可以根据上面链接下载Argoverse api,这里都是使用v1.1版本的。Argoverse api v1.1仅支持MacOS和Linux,下载完后,对应下面步骤进行安装:
创建虚拟环境 python版本为3.8。(不懂的可以找我之前发的利用anaconda配置虚拟环境)
进入虚拟环境后,进入当前下载的Argoverse api文件夹中,执行
pip install -e ./
也可以选择性安装mayavi、ffmpeg、Stereo tutorial dependencies。我只安装了mayavi,在安装这个之前先安装pip install PyQt5,再安装pip install mayavi。(这个应该也不用安装的,毕竟是视频流的相关库)
可以从数据集官网下载以下三个文件,用于使用Jupyter Notebook测试是否安装成功。可以打开Argoverse api中的Usage文件夹,打开你想要测试的,比如你想要测试轨迹预测,那就点击关于“Forecasting”的文件,逐个执行就可以了。

再介绍HiVT代码
先上代码链接
https://github.com/ZikangZhou/HiVT
配置环境
conda create -n HiVT python=3.8
conda activate HiVT
# CUDA 10.2
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
# CUDA 11.1
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
# CPU Only
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cpuonly -c pytorch
conda install pytorch-geometric==1.7.2
conda install pytorch-lightning==1.5.2
根据自己CUDA版本下载对用的pytorch,个人觉得pytorch在1.8.0-1.10.0左右版本应该都差不多,这个CUDA只要你大于11.1都可以向下兼容安装这两个版本的cudatoolkit。
你会发现pytorch-geometric有可能直接安装不了,这是很正常,进阶安装如下:
先进入轮子地址:https://pytorch-geometric.com/whl/
再寻找你的对应pytorch版本和cudatoolkit版本,点进去找到下面whl

然后再 pip install torch-geometric==1.7.2
到此,应该安装完毕了
准备数据集
回到我们的数据集官网,下载

将这些数据集解压放置新建dataset文件夹中,格式如下

Training
训练小一点的模型 HiVT-64,可以适当修改batchsize或者其他参数,实测batchsize=32时候,占用显存才6g+。文章来源:https://www.toymoban.com/news/detail-613673.html
python train.py --root dataset/ --embed_dim 64
训练大一点的模型HiVT-128文章来源地址https://www.toymoban.com/news/detail-613673.html
python train.py --root dataset/ --embed_dim 128
Evaluation
python eval.py --root dataset/ --batch_size 32 --ckpt_path /path/to/your_checkpoint.ckpt
到了这里,关于多目标运动轨迹预测HiVT代码跑通的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!