テクめも

プログラミング関連のちょっとしたTipsなどを書いています。

PyTorchでGlobal Max PoolingとGlobal Average Poolingを使う方法

はじめに

Global Max PoolingやGlobal Average Poolingを使いたいとき、KerasではGlobalAveragePooling1Dなどを用いると簡単に使うことができますが、PyTorchではそのままの関数はありません。

そこで、PyTorchでは、Global Max PoolingやGlobal Average Poolingを用いる方法を紹介します。

Poolingについては以下の記事を読むとイメージがつきやすいです。

用いるデータ

デモのため以下のような6 x 6のデータを使いたいと思います。

import torch
x = torch.arange(0, 36, dtype=torch.float32).view(1, 6, 6)
tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.],
         [18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34., 35.]]])

なお、dtypeを指定しないとintになるので、functionを使うことができません。

Global Max Pooling

PyTorchのフォーラムでも同様の質問がありました。

回答としては、普通のmax_pool2dを用いて実装する方法です。

import torch.nn.functional as F
F.max_pool2d(x, kernel_size=x.size()[2:])

実際に試してみると、出力は

tensor([[[35.]]])

となります。

ちなみに、x.size()[2:]6となります。

また、adaptive_max_pool2dを使っても同じことできます。

F.adaptive_max_pool2d(x, (1, 1))

Global Average Pooling

Global Average PoolingもGlobal Max Poolingと同様です。

F.avg_pool2d(x, kernel_size=x.size()[2:])
F.adaptive_avg_pool2d(x, (1, 1))

出力はともに

tensor([[[17.5000]]])

となります。

余談

ちなみに、普通のMax PoolingとAverage Poolingも同じデータで出力を見てみました。

Max Pooling

F.max_pool2d(x, kernel_size=2, stride=2)
tensor([[[ 7.,  9., 11.],
         [19., 21., 23.],
         [31., 33., 35.]]])

なお、strideは省略するとkernel_sizeと同じ値になるので

F.max_pool2d(x, 2)

と同じです。

Average Pooling

F.avg_pool2d(x, kernel_size=2, stride=2)
tensor([[[ 3.5000,  5.5000,  7.5000],
         [15.5000, 17.5000, 19.5000],
         [27.5000, 29.5000, 31.5000]]])