地平線 3D 目標(biāo)檢測 Bevformer 參考算法 V2.0
簡介
BEVFormer 是當(dāng)前熱門的自動駕駛系統(tǒng)中的 3D 視覺感知任務(wù)模型。BEVFormer 是一個(gè)端到端的框架,BEVFormer 可以直接從原始圖像數(shù)據(jù)生成 BEV 特征,無需依賴于傳統(tǒng)的圖像處理流程。它通過利用 Transformer 架構(gòu)和注意力機(jī)制,有效地從多攝像頭圖像中學(xué)習(xí)生成高質(zhì)量的鳥瞰圖(Bird's-Eye-View, BEV)特征表示。相較于其他的 BEV 轉(zhuǎn)換方式:
時(shí)空注意力機(jī)制:模型結(jié)合了空間交叉注意力(Spatial Cross-Attention, SCA)和時(shí)間自注意力(Temporal Self-Attention, TSA),使網(wǎng)絡(luò)能夠同時(shí)考慮空間和時(shí)間維度上的信息。融合歷史 bev 特征來提升預(yù)設(shè)的 BEV 空間中的 query 的自學(xué)能力,得到 bev 特征。
Deformable attn:通過對每個(gè)目標(biāo)生成幾個(gè)采樣點(diǎn)和采樣點(diǎn)的 offset 來提取采樣點(diǎn)周圍的重要特征,即只關(guān)注和目標(biāo)相關(guān)的特征,減少計(jì)算量。
transformer 架構(gòu):能夠有效捕捉序列中的長期依賴關(guān)系,適用于處理圖像序列。
模型參數(shù):
性能精度表現(xiàn):
模型介紹
·公版 BEVFormer 模型主要可以分為以下幾個(gè)關(guān)鍵部分:
Backbone 網(wǎng)絡(luò):用于從多視角攝像頭圖像中提取特征,本文為 tiny 版本,因此為 ResNet50。
時(shí)空特征提取:BEVFormer 通過引入時(shí)間和空間特征來學(xué)習(xí) BEV 特征。具體來說,模型包括:
Temporal Self-Attention(時(shí)間自注意力):利用前一時(shí)刻的 BEV 特征作為歷史特征,通過自注意力機(jī)制來計(jì)算當(dāng)前時(shí)刻的 BEV 特征。
Spatial Cross-Attention(空間交叉注意力):進(jìn)行空間特征注意力,融合多視角圖像特征。
Deformable Attention(可變形注意力):BEVFormer 使用可變形注意力機(jī)制來加速運(yùn)算,提高模型對不同視角圖像特征的適應(yīng)性。
BEV 特征生成:通過時(shí)空特征的融合,完成環(huán)視圖像特征向 BEV 特征的建模。
Decoder:設(shè)計(jì)用于 3D 物體檢測的端到端網(wǎng)絡(luò)結(jié)構(gòu),基于 2D 檢測器 Deformable DETR 進(jìn)行改進(jìn),以適應(yīng) 3D 空間的檢測任務(wù)。
公版 bevformer 在 征程 6 上部署相比于 征程 5 來說更簡單了,需要考慮的因素更少。征程 6 對非 4 維的支持可以和 4 維的同等效率,因此 征程 6 支持公版的注意力實(shí)現(xiàn),不再限制維度,因此無需對維度做 Reshape,可直接支持公版寫法。但需注意的是公版的 bev_mask 會導(dǎo)致動態(tài) shape。征程 6 不支持動態(tài)輸入,因此 bev_mask 無法使用。在精度上,我們修復(fù)了公版的 bug 已獲得了精度上的提升,同時(shí)通過對關(guān)鍵層做 int16 的量化精度配置以保障 1%以內(nèi)的量化精度損失。
下面將部署優(yōu)化對應(yīng)的改動點(diǎn)以及量化配置依次說明。
性能優(yōu)化改動點(diǎn) 1:
將 attention 層的 mean 替換為 conv 計(jì)算,使性能上獲得提升。
/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/attention.py
self.query_reduce_mean = nn.Conv2d(
self.num_bev_queue * self.reduce_align_num,
self.reduce_align_num,
1,
bias=False,)
# init query_reduce_mean weight
query_reduce_mean_weight = torch.zeros(
self.query_reduce_mean.weight.size(),
dtype=self.query_reduce_mean.weight.dtype,
)
for i in range(self.reduce_align_num):
for j in range(self.num_bev_queue):
query_reduce_mean_weight[i, j * self.reduce_align_num + i] = (
1 / self.num_bev_queue
)
self.query_reduce_mean.weight = torch.nn.Parameter(
query_reduce_mean_weight, requires_grad=False
)
改動點(diǎn) 2:
公版中,在 Encoder 的空間融合模塊,會根據(jù) bev_mask 計(jì)算有效的 query 和 reference_points,輸出 queries_rebatch 和 reference_points_rebatch,作用為減少交互的數(shù)據(jù)量,提升模型運(yùn)行性能。對于稀疏的 query 做 crossattn 后再將 query 放回到 bev_feature 中。
以上提取稀疏 query 步驟的主要算子為 gather,放回 bev_feature 步驟的主要算子為 scatter。由于工具鏈對這兩個(gè)算子暫未支持(gather 算子 930 已支持)而且 bev_mask 為動態(tài)的,為了提升模型的運(yùn)行性能,工具鏈提供了 gridsample 算子的替換方式,index 計(jì)算只與內(nèi)外參有關(guān),因此作為前處理,將計(jì)算好的 index 作為模型輸入即可。
gather
gather 為根據(jù) bevmask 來提取稀疏 query,降低 cross attn 的數(shù)據(jù)量,提升運(yùn)行效率。
代碼路徑:<code>/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/<span style="caret-color: #000000; color: #000000; font-family: monospace; font-size: medium; font-style: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: auto; text-align: start; text-indent: 0px; text-transform: none; white-space: normal; widows: auto; word-spacing: 0px; -webkit-text-stroke-width: 0px; background-color: #e8e8e8; text-decoration: none; display: inline !important; float: none;">view_transformer.py</span>
reference_points_cam = torch.clamp(
reference_points_cam, min=-2.1, max=2.1
)
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
bev_mask_ori = bev_mask.clone()
max_len = self.virtual_bev_h * self.virtual_bev_w
queries_rebatch_grid = reference_points_cam.new_zeros(
[B * self.numcam, self.virtual_bev_h, self.virtual_bev_w, 2]
)
for camera_idx, mask_per_img_bs in enumerate(bev_mask):
for bs_id, mask_per_img in enumerate(mask_per_img_bs):
temp_grid = (
torch.zeros(
(max_len, 2),
device=queries_rebatch_grid.device,
dtype=torch.float32,
)
- 1.5
)
index_query_per_img = (
mask_per_img.sum(-1).nonzero().squeeze(-1)
)
num_bev_points = index_query_per_img.shape[0]
camera_idx_tensor_x = index_query_per_img % self.bev_w
camera_idx_tensor_y = index_query_per_img // self.bev_w
index_grid = torch.stack(
[
camera_idx_tensor_x / (self.bev_w - 1),
camera_idx_tensor_y / (self.bev_h - 1),
],
dim=-1,
)
index_grid = index_grid * 2 - 1
temp_grid[:num_bev_points] = index_grid
temp_grid = temp_grid.reshape(
self.virtual_bev_h, self.virtual_bev_w, 2
)
queries_rebatch_grid[
bs_id * self.numcam + camera_idx
] = temp_grid
reference_points_rebatch = (
reference_points_cam.flatten(-2)
.permute(1, 0, 3, 2)
.flatten(0, 1)
.reshape(B * self.numcam, D * 2, self.bev_h, self.bev_w)
)
reference_points_rebatch = (
F.grid_sample(
reference_points_rebatch,
queries_rebatch_grid,
mode="nearest",
align_corners=True,
)
.flatten(-2)
.permute(0, 2, 1)
.reshape(B * self.numcam, max_len, D, 2)
)
query_rebatch
queries_rebatch = (
query.unsqueeze(1)
.repeat(1, self.num_cams, 1, 1)
.reshape(
bs * self.num_cams, self.bev_h, self.bev_w, self.embed_dims
)
.permute(0, 3, 1, 2)
)
queries_rebatch = F.grid_sample(
queries_rebatch,
queries_rebatch_grid,
mode="nearest",
align_corners=True,
)
queries_rebatch = queries_rebatch.flatten(-2).permute(0, 2, 1)
reference_points_rebatch = reference_points_rebatch.flatten(
-2
).unsqueeze(-2)
scatter
scatter 操作對經(jīng)過 deformable_attention 后的 query 放入到 bevfeature 中,然后求平均。
代碼路徑為:/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/attention.py
slots = self.restore_outputs(
restore_bev_grid,
queries_out,
bev_pillar_counts,
bs,
queries_rebatch_grid,
)
def restore_outputs(
self,
restore_bev_grid: Tensor,
queries_out: Tensor,
counts: Tensor,
bs: int,
queries_rebatch_grid: Tensor,
):
"""Restore outputs to bev feature."""
queries_out = queries_out.reshape(
bs, self.num_cams, self.embed_dims, -1
)
queries_out = queries_out.permute(0, 2, 1, 3)
queries_out = queries_out.reshape(
bs,
self.embed_dims,
self.num_cams * queries_rebatch_grid.shape[1],
queries_rebatch_grid.shape[2],
)
bev_queries = F.grid_sample(
queries_out, restore_bev_grid, mode="nearest", align_corners=True
)
bev_queries = bev_queries.reshape(bs, -1, self.bev_h, self.bev_w)
slots = self.query_reduce_sum(bev_queries).flatten(-2).permute(0, 2, 1)
slots = self.mul_pillarweight.mul(slots, counts)
return slots
其中 restore_bev_grid,根據(jù) bevmask 反算回 bev_feature 的位置:
restore_bev_grid = (精度優(yōu)化浮點(diǎn)精度
reference_points_cam.new_zeros(
B, self.max_camoverlap_num * self.bev_h, self.bev_w, 2
)
- 1.5
)
for bs_id, bev_mask_ in enumerate(bev_mask):
bev_pillar_num_map = torch.zeros(
(self.bev_h, self.bev_w), device=bev_mask_.device
)
count = bev_mask_.sum(-1) > 0
camera_idxs, bev_pillar_idxs = torch.where(count)
camera_idx_offset = 0
for cam_id in range(self.numcam):
camera_idx = torch.where(camera_idxs == cam_id)
bev_pillar_idx_cam = bev_pillar_idxs[camera_idx[0]]
num_camera_idx = len(camera_idx[0])
camera_idx_tmp = camera_idx[0] - camera_idx_offset
camare_tmp_idx_x = camera_idx_tmp % self.virtual_bev_w
camare_tmp_idx_y = camera_idx_tmp // self.virtual_bev_w
grid_x = camare_tmp_idx_x
grid_y = cam_id * self.virtual_bev_h + camare_tmp_idx_y
bev_pillar_idx_cam_x = bev_pillar_idx_cam % self.bev_w
bev_pillar_idx_cam_y = bev_pillar_idx_cam // self.bev_w
bev_pillar_num_map_tmp = bev_pillar_num_map[
bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
]
grid_h = (
bev_pillar_num_map_tmp * self.bev_h + bev_pillar_idx_cam_y
).to(torch.int64)
grid_w = (bev_pillar_idx_cam_x).to(torch.int64)
restore_bev_grid[bs_id, grid_h, grid_w, 0] = grid_x / (
self.virtual_bev_w - 1
)
restore_bev_grid[bs_id, grid_h, grid_w, 1] = grid_y / (
self.numcam * self.virtual_bev_h - 1
)
bev_pillar_num_map[
bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
] = (
bev_pillar_num_map[
bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
]
+ 1
)
camera_idx_offset = camera_idx_offset + num_camera_idx
restore_bev_grid = restore_bev_grid * 2 - 1
改動點(diǎn) 3:
公版通過 can_bus 初始化 ref 來做時(shí)序融合,然而這個(gè)時(shí)候 bev feat 并沒有對齊,在 attention 計(jì)算時(shí)不能簡單的 concat 起來。因此我們換了一種時(shí)序?qū)R的方式,通過前后兩幀的 ego2global 坐標(biāo)系轉(zhuǎn)換矩陣將當(dāng)前幀的 bev 特征和上一幀對齊,此時(shí) ref 都是一樣的。(非 征程 6 不支持,為公版 bug),精度上獲得提升。
/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/view_transformer.py` `get_prev_bev` `get_fusion_ref
pre_scene = prev_meta["scene_token"]
for i in range(bs):
if pre_scene[i] != cur_meta["meta"][i]["scene"]:
prev_bev[i] = torch.zeros(
(self.bev_h * self.bev_w, self.embed_dims),
dtype=torch.float32,
device=device,
)
##公版:
shift_ref_2d = ref_2d.clone()
shift_ref_2d += shift[:, None, None, :]
bs, len_bev, num_bev_level, _ = ref_2d.shape
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
bs*2, len_bev, num_bev_level, 2)
##地平線版本
shift_ref_2d = ref_2d.clone()
bs, len_bev, num_bev_level, _ = ref_2d.shape
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
bs * 2, len_bev, num_bev_level, 2
)
改動點(diǎn) 4:
修復(fù)了個(gè) tsa 公版的 batchsize 不等于 1 的 bug。
量化精度為量化精度保證,我們將以下的算子配置為 int16 或 int32 輸出:
view_transformer:輸入節(jié)點(diǎn)做 int16 量化:
int16_models = [
self.quant_hybird_ref_2d,
self.quant_norm_coords,
self.quant_restore_bev_grid,
self.quant_reference_points_rebatch,
self.quant_queries_rebatch_grid,
]
for m in int16_models:
m.qconfig = qconfig_manager.get_qconfig(
activation_qat_qkwargs={"dtype": qint16},
activation_calibration_qkwargs={
"dtype": qint16,
},
activation_calibration_observer="mix",
)
attention 層:最后兩個(gè) conv 和 add 開啟 int16
def set_qconfig(self) -> None:
"""Set the quantization configuration."""
from hat.utils import qconfig_manager
int16_module = [
self.output_proj,
self.add_res,
]
decoder 層:cls_branches、reg_branches 的 conv 配置為 int32 輸出;sigmoid 和 reference_points 配置為 int16
def set_qconfig(self) -> None: """Set the quantization configuration.""" from hat.utils import qconfig_manager for _, m in enumerate(self.cls_branches): m[0].qconfig = qconfig_manager.get_qconfig( activation_qat_qkwargs={"dtype": qint16}, activation_calibration_qkwargs={ "dtype": qint16, }, activation_calibration_observer="mix", ) m[3].qconfig = qconfig_manager.get_qconfig( activation_qat_qkwargs={"dtype": qint16}, activation_calibration_qkwargs={ "dtype": qint16, }, activation_calibration_observer="mix", ) m[-1].qconfig = qconfig_manager.get_default_qat_out_qconfig() self.reg_branches[-1][ -1 ].qconfig = qconfig_manager.get_default_qat_out_qconfig() self.query_embedding.qconfig = None int16_module = [ self.reference_points, self.sigmoid, ] for m in int16_module: m.qconfig = qconfig_manager.get_qconfig( activation_qat_qkwargs={"dtype": qint16}, activation_calibration_qkwargs={ "dtype": qint16, }, activation_calibration_observer="mix", )總結(jié)與建議訓(xùn)練建議
浮點(diǎn)和公版一致即可
qat 訓(xùn)練需要將 lr 降低,下降策略建議使用 StepDecayLrUpdater。
建議 bev size 的選擇考慮性能影響。征程 6 相比于 征程 5 帶寬增大,但仍需注意 bevsize 過大導(dǎo)致訪存時(shí)間過長對性能的影響,建議考慮實(shí)際部署情況選擇合適的 bevsize 做性能驗(yàn)證。
使用 bevmask 來提升運(yùn)行性能,可參考 4.1 章節(jié)使用 gridsample 替換不支持的 scatter。
在注意力機(jī)制中存在一些 ElementWise 操作,對于導(dǎo)致性能瓶頸的可以考慮 conv 替換,對于造成量化風(fēng)險(xiǎn)的可以根據(jù)敏感度分析結(jié)果合理選擇更高的量化精度,以確保注意力機(jī)制的部署。
本文通過對 Bevformer 在地平線征程 6 上量化部署的優(yōu)化,使得模型在該計(jì)算方案上用低于 1%的量化精度損失,得到 latency 為 45.74ms 的部署性能,同時(shí),通過 Bevformer 的部署經(jīng)驗(yàn),可以推廣到其他模型部署優(yōu)化,例如包含 MSDA 模型結(jié)構(gòu)、transformer-based BEV 的部署。
附錄論文:https://arxiv.org/pdf/2203.17270
公版代碼:
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。