はじめに
Global Max PoolingやGlobal Average Poolingを使いたいとき、KerasではGlobalAveragePooling1Dなどを用いると簡単に使うことができますが、PyTorchではそのままの関数はありません。
そこで、PyTorchでは、Global Max PoolingやGlobal Average Poolingを用いる方法を紹介します。
Poolingについては以下の記事を読むとイメージがつきやすいです。
用いるデータ
デモのため以下のような6 x 6
のデータを使いたいと思います。
import torch x = torch.arange(0, 36, dtype=torch.float32).view(1, 6, 6)
tensor([[[ 0., 1., 2., 3., 4., 5.], [ 6., 7., 8., 9., 10., 11.], [12., 13., 14., 15., 16., 17.], [18., 19., 20., 21., 22., 23.], [24., 25., 26., 27., 28., 29.], [30., 31., 32., 33., 34., 35.]]])
なお、dtype
を指定しないとint
になるので、functionを使うことができません。
Global Max Pooling
PyTorchのフォーラムでも同様の質問がありました。
回答としては、普通のmax_pool2d
を用いて実装する方法です。
import torch.nn.functional as F F.max_pool2d(x, kernel_size=x.size()[2:])
実際に試してみると、出力は
tensor([[[35.]]])
となります。
ちなみに、x.size()[2:]
は6
となります。
また、adaptive_max_pool2d
を使っても同じことできます。
F.adaptive_max_pool2d(x, (1, 1))
Global Average Pooling
Global Average PoolingもGlobal Max Poolingと同様です。
F.avg_pool2d(x, kernel_size=x.size()[2:])
F.adaptive_avg_pool2d(x, (1, 1))
出力はともに
tensor([[[17.5000]]])
となります。
余談
ちなみに、普通のMax PoolingとAverage Poolingも同じデータで出力を見てみました。
Max Pooling
F.max_pool2d(x, kernel_size=2, stride=2)
tensor([[[ 7., 9., 11.], [19., 21., 23.], [31., 33., 35.]]])
なお、stride
は省略するとkernel_size
と同じ値になるので
F.max_pool2d(x, 2)
と同じです。
Average Pooling
F.avg_pool2d(x, kernel_size=2, stride=2)
tensor([[[ 3.5000, 5.5000, 7.5000], [15.5000, 17.5000, 19.5000], [27.5000, 29.5000, 31.5000]]])