损失函数
损失函数衡量模型预测值与真实值之间的差距。训练的目标是最小化损失函数。
本项目实现了两个损失函数:
- MSELoss(均方误差)— 用于回归任务
- CrossEntropyLoss(交叉熵,内置 Softmax)— 用于分类任务
均方误差 (MSE)
数学定义
给定
更一般地,当预测和目标为多维张量时,
其中
梯度推导
对单个预测值
只有
写成向量形式:
代码实现
src/nn/losses/mseLoss.py:
python
def forward(self, predictions: np.ndarray, targets: np.ndarray) -> float:
difference = predictions - targets
loss = np.mean(difference**2)
# 缓存以供 backward 使用
self.predictions = predictions
self.targets = targets
self.elementCount = predictions.size
return float(loss)
def backward(self) -> np.ndarray:
# dL/dPred = 2 * (predictions - targets) / N
inputGradient = 2 * (self.predictions - self.targets) / self.elementCount
return inputGradient使用场景
- 回归任务:预测连续值(如 Sine 回归)
- 要求
predictions和targets形状完全一致 - 当训练器中
taskType="regression"时使用
交叉熵损失(Softmax + CrossEntropy)
为什么需要 Softmax
对于
性质:
(总和为 1) (始终为正) - 保持相对大小关系
交叉熵损失定义
其中
直觉:若模型对真实类别赋予了高概率(
数值稳定性技巧
直接计算
证明:分子分母同时乘以
最优雅的梯度
Softmax + 交叉熵的组合有一个令人惊讶的简洁梯度:
其中
推导概要:
对
代码实现
src/nn/losses/crossEntropyLoss.py:
python
def forward(self, logits: np.ndarray, targetLabels: np.ndarray) -> float:
# 数值稳定:减去每行最大值
shiftedLogits = logits - np.max(logits, axis=1, keepdims=True)
# Softmax
expLogits = np.exp(shiftedLogits)
probabilities = expLogits / np.sum(expLogits, axis=1, keepdims=True)
# 取真实类别的概率
selectedProbs = probabilities[np.arange(batchSize), targetLabels]
# 裁剪避免 log(0)
clippedProbs = np.clip(selectedProbs, self.epsilon, 1.0)
# 交叉熵损失
loss = -np.mean(np.log(clippedProbs))
# 缓存 softmax 概率用于 backward
self.probabilities = probabilities
self.targetLabels = targetLabels
self.batchSize = batchSize
return float(loss)
def backward(self) -> np.ndarray:
# dL/dz = (p - one_hot(y)) / N
inputGradient = self.probabilities.copy()
inputGradient[np.arange(self.batchSize), self.targetLabels] -= 1.0
inputGradient /= self.batchSize
return inputGradient关键实现细节:
shiftedLogits是数值稳定技巧的核心probabilities缓存了整个 softmax 输出矩阵,供backward使用backward中的-= 1.0实现了,其中 np.arange(batchSize)索引每个样本的真实类别位置epsilon参数(默认1e-12)防止
使用场景
- 分类任务(二分类或多分类)
targetLabels必须是整数索引(不是 one-hot)- 当训练器中
taskType="classification"时使用
损失函数对比
| 特性 | MSELoss | CrossEntropyLoss |
|---|---|---|
| 任务类型 | 回归 | 分类 |
| 输入 | 预测值(任意形状) | logits(2D) |
| 目标 | 真实值(同形状) | 整数类别索引(1D) |
| 数值技巧 | 无 | logit-shift + epsilon clipping |
| 梯度与输入同形状 | 是 | 是 |