量化训练中fusebn与withbn核心差异解析
摘要
1 BN 简介 Batch Normalization(BN),说白了就是深度学习中常用的一种归一化手段。它的核心操
1.BN 简介
Batch Normalization(BN),说白了就是深度学习中常用的一种归一化手段。它的核心操作很直接:对每个 batch 的特征做 \((x-\mu)/\sigma\) 归一化——\(\mu\) 是均值,\(\sigma\) 是方差——然后用可学习的 \(\gamma\)、\(\beta\) 调整一下分布。
训练过程中,BN 会通过滑动更新的方式维护全局的 running_mean 和 running_var。在 PyTorch 里,默认的更新规则长这样:
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var
你看,这里的 momentum 默认是 0.1:
nn.BatchNorm2d(num_features, momentum=0.1)
也就是说,它控制着全局统计量的更新速度。推理的时候,BN 直接用这些预存的统计值,不再依赖 batch 数据。所以推理阶段的 BN 公式就变成了:
2. freezebn 简介
freezebn,就是训练时把 BN 层“冻住”——通常冻结 running_mean/running_var 的更新,同时固定 \(\gamma/\beta\)(也有只冻结统计量的情况)。这样一来,BN 层就以“推理模式”运行了。
但 freezebn 不是所有浮点训练场景都用得上,主要是在迁移学习/微调(Fine-tuning)和小批量训练这些场景里才会派上用场。
2.1 迁移学习 / 微调
在迁移学习或微调时用 freezebn,最核心的目的就一个:别让预训练模型学到的 BN 统计量被搞坏了。这大概是 freezebn 最常见的用法。
- 预训练模型的 BN 层已经在一堆通用数据(比如 ImageNet)上把 running_mean/running_var 训练得挺稳了,这些统计量是模型泛化能力的重要基础;
- 微调用的新数据集往往小得多,而且分布和预训练集差得远(比如从通用图像转到人脸识别、工业缺陷检测)。小批量数据的 batch_mean/batch_var 很容易偏离预训练的全局统计量;
- 如果不冻结 BN,微调时新数据会不断更新 running_mean/running_var,预训练积累的有效信息就会被覆盖,模型快速“忘掉”通用特征,最后过拟合或者精度暴跌;
- 冻结 BN 后,BN 层直接用预训练的全局统计量,只更新后面的分类层/任务层参数,既保留了通用特征,又能适配新任务。
2.2 小批量训练
浮点训练时 batch size 如果太小(比如 ≤8),freezebn 能解决 BN 层统计量不稳定的问题:
- BN 层算出来的 batch_mean/batch_var 会和真实全局统计量差很多——批次越小,随机性越强,归一化效果就失效了,训练过程容易震荡甚至不收敛;
- 冻结 BN 后,直接用预训练的稳定统计量,能绕开小批量带来的统计噪声,训练过程自然更稳(迁移学习里也常用这个思路)。
2.3 freezebn 示例
import torch
import torch.nn as nn
import torchvision.models as models
def freeze_bn(model):
"""
冻结模型中所有BN层:
1. 设置eval(),让BN以推理模式运行(不更新统计量)
2. 冻结BN层的参数(γ/β不更新)
"""
# 实现细节略...
好了,回到量化部署上。常听人说起:浮点训练时用 freezebn 的技巧,QAT 阶段建议试试 withbn 的训练方式;当然,常规情况下还是推荐 QAT 采用 fusebn。下面我们具体看看量化部署时,fusebn 和 withbn 到底是怎么回事。
3. conv 与 bn 融合原理
地平线算法工具链在做量化训练 prepare 时,默认会把 Conv(卷积)和 BN(批归一化)融合。为什么?因为 BN 层会动态改变张量的数值范围,融合之后就能把这个动态变化“固化”成卷积的静态参数,消除它对量化的干扰,让量化的“范围映射”更准,从而提升量化模型的精度。
分开看,单独的 Conv 卷积分和 BN 计算可以拆成下面几步:
- 卷积层输出:
- BN 层输出(训练时):

- 把 BN “固化”到卷积参数里:用 BN 的全局统计量(\(\mu/\sigma^2\)),把 BN 的计算“合并”到卷积层里。

融合之后,Conv+BN 就等价于一个“新的卷积层”,不再需要单独的 BN 计算,推理效率自然就上去了。
4. fusebn/withbn 介绍
qat_mode 用来设置在 QAT(量化感知训练)阶段是否带 BN 进行量化训练。如果在浮点训练中用了 freeze bn 的技巧,那么 QAT 训练时 qat_mode 就要设成 withbn。
qat_mode 有三种可选设置:
class QATMode(object):
FuseBN = "fuse_bn"
WithBN = "with_bn"
WithBNReverseFold = "with_bn_reverse_fold" # 先不用关注
4.1 fuse_bn
QAT 阶段不带 BN,这是 horizon_plugin_pytorch 默认的量化训练方式。
把 qat_mode 设成 fuse_bn 后,浮点模型在做 op 融合时,BN 的 weight 和 bias 都会被吸收进 Conv 的 weight 和 bias 里,原来 Conv + BN 的组合就只剩下 Conv 了。这个吸收过程理论上没有误差。
4.2 with_bn
QAT 阶段带 BN 训练。
把 qat_mode 设成 with_bn 后,浮点模型转成 QAT 模型时,BN 不会吸收进 Conv。QAT 阶段 Conv + BN + 输出量化节点 会作为一个被融合的量化 op 存在。量化训练结束 convert 转成 quantized 模型时,BN 的 weight 和 bias 才会自动吸收进 conv 的量化参数中。吸收之后得到的 quantized op 和原来的 QAT op 计算结果一致。
至于为什么说“理论上吸收前后无损”或“无变化”呢?因为实际计算中,吸收前后两次浮点计算结果在少数情况下可能在小数点很靠后的数位上不一致。这种微小的差异加上量化操作,有可能导致吸收 BN 后 Conv 的输出和吸收前 Conv + BN 的输出在部分数值上产生一个输出 scale 的绝对误差。
4.3 使用用法(重要)
记得在 prepare 之前设置 qat_mode。calib 和 QAT 的 prepare 前必须保持一致,否则会出现 qat_model 无法加载 calib_model_ckpt 的问题。
from horizon_plugin_pytorch.qat_mode import QATMode, set_qat_mode
set_qat_mode(QATMode.WithBN)
calib_qat_net = prepare(float_model, (input_tensor),
qconfig_setter=qconfig_setter)
通常情况下,训练流程是先浮点训练到理想精度,再做量化训练。这时候直接用 fuse_bn 就足够了。
来源:互联网
本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。