用 PyTorch 实现常见函数

恶补Pytorch的基础
本篇blog以Pytorch为框架,实现一些常见的函数

实现softmax和cross-entropy

定义softmax操作

1
2
3
4
def softmax(input):
i_exp = input().exp() # 求出其指数值
partition = i_exp.sum(dim=1, keepdim=True) # 把每一行的指数值求和
return i_exp / partition # 将每个指数值除以求和的结果

测试:
softmax(torch.tensor([[3, 5, 5.2], [2, 3.5, 6]])),得到tensor([[0.0574, 0.4243, 0.5183],[0.0166, 0.0746, 0.9088]])

1
2
3
def cross_entropy(y_hat, y):
value = y_hat.gather(1, y.reshape(-1, 1))
return -torch.log(value)

7.6 updated
还可以通过torch的索引操作直接读取元素

1
2
def cross_entropy2(y_hat, y):
return -torch.log(y_hat[range(len(y_hat)), y])

所以对于网络输出得到的logitstorch.tensor([[3, 5, 5.2], [2, 3.5, 6]]),只需要先输入softmax得到概率分布,再进行cross_entropy得到最终的loss

1
2
3
logits = torch.tensor([[3, 5, 5.2], [2, 3.5, 6]])
label = torch.tensor([1, 2])
print(softmax(cross_entropy(softmax(logits), label)))

如果是以one-hot形式给出来的标签,则可以通过softmax获取概率后直接使用矩阵按元素乘后再取sum的方式获取结果(商汤二面)

1
2
3
4
5
6
7
logits = torch.tensor([[3, 5, 5.2], [2, 3.5, 6]])
label = torch.tensor([[1, 0, 0], [0, 0, 1]])
loss = softmax(logits) * label
# tensor([[0.0574, 0.0000, 0.0000],
# [0.0000, 0.0000, 0.9088]])
print(loss.sum(dim=1))
# tensor([0.0574, 0.9088])