【Python】オフラインでfine-tuningのpretrainedを使用する方法【PyTorch】
ネットワーク環境にないGPU計算機でfine-tuningのpretrainedをする方法についてまとめます。 今回はresnet18を使用するケースを考えていますが、同じ手順で他のモデルも同様にオフライン実行できると思います。
目次
環境
fine-tuningのpretrainedとは
pytorchでfine-tuningするときmodels.nameでモデルの型を参照し、pretrained=Trueでパラメータを付与する処理になっています。
model_ft = models.resnet18(pretrained=True)
このパラメータが保存されているオブジェクトファイルはライブラリが所持しているものではなく、オンラインからダウンロードする処理になっています。
因みにpretrained=Falseにすればオフライン実行できますが、パラメータは1からの学習になりfine-tuningの強みをあまり活かせません。
torchvision.modelsのオンライン参照箇所
model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', } def _resnet(arch, block, layers, pretrained, progress, **kwargs): model = ResNet(block, layers, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model
変更する処理
1. 使用するモデルの学習済みモデルをダウンロードし、ローカルに保存する
ダウンロード先は上のmodel_urlsに書いてあります。 例えば、resnet18を使用したい場合は https://download.pytorch.org/models/resnet18-5c106cde.pth からダウンロードします。
2. pretrained=Falseでモデルの型のみ取得する
model_ft = models.resnet18(pretrained=False)
3. 上で定義したモデルに保存した学習済みモデルをloadする
ここでは保存先フォルダを model_pretrained にしています。
model_ft.load_state_dict(torch.load('model_pretrained/resnet18-5c106cde.pth'))
最後に
本来はライブラリ関数のオーバーライドとかで綺麗に実装したかったんですけど、色々やって断念しました。
もしオーバーライドで出来た方がいればコメントで教えてもらえると助かります🙏