恶补Pytorch的基础
本篇blog以Pytorch为框架,实现一些常见的函数
实现softmax和cross-entropy
定义softmax操作
1 | def softmax(input): |
测试:softmax(torch.tensor([[3, 5, 5.2], [2, 3.5, 6]]))
,得到tensor([[0.0574, 0.4243, 0.5183],[0.0166, 0.0746, 0.9088]])
1 | def cross_entropy(y_hat, y): |
7.6 updated
还可以通过torch的索引操作直接读取元素1
2def cross_entropy2(y_hat, y):
return -torch.log(y_hat[range(len(y_hat)), y])
所以对于网络输出得到的logitstorch.tensor([[3, 5, 5.2], [2, 3.5, 6]])
,只需要先输入softmax得到概率分布,再进行cross_entropy得到最终的loss1
2
3logits = torch.tensor([[3, 5, 5.2], [2, 3.5, 6]])
label = torch.tensor([1, 2])
print(softmax(cross_entropy(softmax(logits), label)))
如果是以one-hot形式给出来的标签,则可以通过softmax获取概率后直接使用矩阵按元素乘后再取sum的方式获取结果(商汤二面)1
2
3
4
5
6
7logits = torch.tensor([[3, 5, 5.2], [2, 3.5, 6]])
label = torch.tensor([[1, 0, 0], [0, 0, 1]])
loss = softmax(logits) * label
# tensor([[0.0574, 0.0000, 0.0000],
# [0.0000, 0.0000, 0.9088]])
print(loss.sum(dim=1))
# tensor([0.0574, 0.9088])