tmori’s blog

公開メモ帳くらいの感覚で書いています。技術系多め。日常少なめ。

【Python】オフラインでfine-tuningのpretrainedを使用する方法【PyTorch】

ネットワーク環境にないGPU計算機でfine-tuningのpretrainedをする方法についてまとめます。 今回はresnet18を使用するケースを考えていますが、同じ手順で他のモデルも同様にオフライン実行できると思います。

目次

環境

fine-tuningのpretrainedとは

pytorchでfine-tuningするときmodels.nameでモデルの型を参照し、pretrained=Trueでパラメータを付与する処理になっています。

 model_ft = models.resnet18(pretrained=True)

このパラメータが保存されているオブジェクトファイルはライブラリが所持しているものではなく、オンラインからダウンロードする処理になっています。

因みにpretrained=Falseにすればオフライン実行できますが、パラメータは1からの学習になりfine-tuningの強みをあまり活かせません。

pytorch.org

torchvision.modelsのオンライン参照箇所

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

変更する処理

1. 使用するモデルの学習済みモデルをダウンロードし、ローカルに保存する

ダウンロード先は上のmodel_urlsに書いてあります。 例えば、resnet18を使用したい場合は https://download.pytorch.org/models/resnet18-5c106cde.pth からダウンロードします。

2. pretrained=Falseでモデルの型のみ取得する

model_ft = models.resnet18(pretrained=False)

3. 上で定義したモデルに保存した学習済みモデルをloadする

ここでは保存先フォルダを model_pretrained にしています。

 model_ft.load_state_dict(torch.load('model_pretrained/resnet18-5c106cde.pth'))

pytorch.org

最後に

本来はライブラリ関数のオーバーライドとかで綺麗に実装したかったんですけど、色々やって断念しました。

もしオーバーライドで出来た方がいればコメントで教えてもらえると助かります🙏