PyTorchでCT画像分類AIを作ってみよう|Dataset・CNN・学習ループを一気通貫で解説

医療画像AIシリーズ第3弾。前回構築した前処理パイプラインをPyTorchのDatasetに組み込み、CNNで「病変あり/なし」を分類するモデルを学習させます。学習ループ・評価・モデル保存まで、動くコード付きで一気通貫に解説します。

この記事でつくるもの ── シリーズのゴール

このシリーズでは、医療画像AIをゼロから自分の手で動かすことを目指してきました。第1弾でDICOM画像の読み込み、第2弾で前処理パイプラインの構築を学び、今回はいよいよ最終目標だった「CNNで画像分類モデルを学習させる」段階に入ります。

⚠ この記事は「中級〜上級編」です

シリーズ第3弾となる本記事は、第1弾 「PythonでDICOM画像を読み込んでみよう|pydicom入門」、第2弾 「PythonでCT画像の前処理を自動化しよう」 を読んでいることに加えて、Pythonの基礎文法とPyTorchの基本(クラスによるモデル定義・学習ループの型)を前提に進みます。Python自体がこれからの方は無料のPython初学者コース、PyTorchが初めての方はE資格コーディング対策コース(無料・全12レッスン)から始めるのがおすすめです。

💡 「流れだけつかみたい」方も大丈夫です

コードを1行ずつ完全に理解できなくても心配いりません。本記事は「医療画像AIはこういう手順で作られていくのか」という全体の流れが眺めるだけでもつかめるように、各ステップの役割を日本語で補足しながら書いています。まずは通して読み、興味が湧いたところでコードを動かしてみてください。

完成形の全体像は次のとおりです。前回までに作った前処理が、そのまま学習パイプラインの入口になります。

# 医療画像AI 学習パイプラインの全体像 # # DICOM読み込み(第1弾) # ↓ # 前処理パイプライン(第2弾: HU変換〜正規化) # ↓ # Dataset / DataLoader(今回: データを束ねる) # ↓ # CNNモデル(今回: 特徴を学習する) # ↓ # 学習ループ → 評価 → モデル保存(今回)

読み終える頃には、「前処理済みの画像とラベルさえ用意すれば、自力で分類AIを学習・評価できる」状態になっているはずです。

準備:ライブラリのインストールと前提

今回の主役は深層学習フレームワークの PyTorch です。CPUだけでも今回のコードは十分動きます(数十秒程度)。

ターミナル
$ pip install torch numpy

PyTorchを初めて触る方は、モデル定義(nn.Module)や学習ループの「型」を先に眺めておくと理解がスムーズです。当サイトのE資格コーディング対策コース(無料・全12レッスン)では、ブラウザ上でこの型を1つずつ読み解けます。

学習データを用意する(まずは合成データで)

本物の症例データは個人情報や倫理審査の壁があり、練習用に気軽には使えません。そこで今回は、「結節(しこり状の病変)あり/なし」を模した合成画像を作って学習の流れを体験します。学習コードの流れ自体は、合成データでも実データでもほぼ同じです。実務では画像のリストを前処理済みCTに差し替えれば動きますが、データ拡張やクラス不均衡への対応、後述する患者単位のデータ分割など、実データならではの追加の工夫が必要になります。

Python
import numpy as np import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader def make_synthetic_ct(n_per_class=100, size=64, seed=0): """「結節あり/なし」を模した2クラスの合成画像を作る""" rng = np.random.default_rng(seed) images, labels = [], [] for label in (0, 1): for _ in range(n_per_class): img = rng.normal(0.3, 0.05, (size, size)) # 背景(軟部組織を模す) if label == 1: # 「結節あり」クラス cy, cx = rng.integers(16, size - 16, 2) r = int(rng.integers(4, 9)) y, x = np.ogrid[:size, :size] img[(y - cy) ** 2 + (x - cx) ** 2 <= r ** 2] += 0.4 images.append(np.clip(img, 0, 1).astype(np.float32)) labels.append(label) return images, labels images, labels = make_synthetic_ct(n_per_class=100) print(f"画像数: {len(images)} 1枚の形: {images[0].shape}") # 画像数: 200 1枚の形: (64, 64)

ポイントは、画像を「0〜1に正規化されたfloat32のNumPy配列」に揃えていることです。これは第2弾の前処理パイプライン preprocess_ct() の出力形式と同じです。つまり実データを使うときは、次のように置き換えるだけです。

Python(実データを使う場合の例)
# 実務では: フォルダごとに分けたDICOMを前処理して同じ形式のリストを作る # images = [preprocess_ct(pydicom.dcmread(f), target_size=(64, 64)) for f in ファイル一覧] # labels = [0 or 1 をファイルの所属フォルダから付与]

💡 target_size をモデルに合わせる

第2弾の preprocess_ct() はデフォルトで target_size=(256, 256) にリサイズします。今回のCNNは64×64入力を前提に設計しているので、上の例のように target_size=(64, 64) を指定してサイズを揃えてください。256×256のまま使う場合は、後述する全結合層の入力サイズを 32×64×64 に変更します(256 → 1回目のプーリングで128 → 2回目で64、と半分ずつ縮むため、最終的な空間サイズが64×64になります)。

⚠ 実データを扱うときの注意

患者データを学習に使う場合は、所属機関の倫理審査・匿名化・データ利用規約に必ず従ってください。本記事の合成データは、あくまで技術習得のための練習用です。

Datasetクラスに前処理を組み込む

PyTorchでは、データの供給を Dataset というクラスに任せます。役割はシンプルで、「何枚あるか(__len__)」と「i番目のデータをください(__getitem__)」に答えられるようにするだけです。

Python
class CTDataset(Dataset): def __init__(self, images, labels): self.images = images # 前処理済み画像(0〜1のnumpy配列)のリスト self.labels = labels # クラス番号のリスト def __len__(self): return len(self.images) def __getitem__(self, idx): img = torch.from_numpy(self.images[idx]).unsqueeze(0) # (H,W) → (1,H,W) label = torch.tensor(self.labels[idx], dtype=torch.long) return img, label

2つだけ、つまずきやすいポイントを補足します。

DataLoaderでミニバッチ化する

Dataset が「1枚ずつ取り出す係」だとすれば、DataLoader「まとめて配る係」です。学習用と検証用にデータを分けてから、それぞれをDataLoaderに渡します。

Python
torch.manual_seed(0) # 再現性のため乱数を固定 dataset = CTDataset(images, labels) n_val = int(len(dataset) * 0.2) n_train = len(dataset) - n_val train_set, val_set = torch.utils.data.random_split(dataset, [n_train, n_val]) train_loader = DataLoader(train_set, batch_size=16, shuffle=True) val_loader = DataLoader(val_set, batch_size=16, shuffle=False) xb, yb = next(iter(train_loader)) print(f"1バッチの形: {xb.shape}, ラベル: {yb.shape}") # 1バッチの形: torch.Size([16, 1, 64, 64]), ラベル: torch.Size([16])

ここで大事なのが shuffle の使い分けです。学習用は shuffle=True(毎エポックで順番を混ぜて学習の偏りを防ぐ)が定石です。いっぽう検証では、データの順番は評価結果に影響しないためシャッフルする必要がなく、shuffle=False にしておくと結果の確認・再現もしやすくなります。バッチの形が「(バッチ16, チャネル1, 高さ64, 幅64)」になっていること、これがCNNに入る最終形です。

⚠ 実際の医療AIでは「患者単位」で分割します

本記事では簡単化のため random_split で画像をランダムに分割しています。しかし実データで、同一患者の画像や同一検査のスライスが学習用と検証用の両方に混ざると、答えを知っている状態で試験を受けるようなもの(データリーク)になり、性能を大きく過大評価してしまいます。実務では患者ID単位で train / validation / test を分割するのが鉄則です。

CNNモデルを定義する

いよいよモデル本体です。CNN(畳み込みニューラルネットワーク)は「畳み込みで特徴を抽出 → プーリングで縮小」を繰り返し、最後に全結合層で分類します。今回は小さく確実に動く2段構成にしました。

Python
class SimpleCNN(nn.Module): def __init__(self, num_classes=2): super().__init__() self.features = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 64 → 32 nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 32 → 16 ) self.classifier = nn.Linear(32 * 16 * 16, num_classes) def forward(self, x): x = self.features(x) x = torch.flatten(x, 1) # (N, 32, 16, 16) → (N, 8192) return self.classifier(x) model = SimpleCNN(num_classes=2) print(model(xb).shape) # torch.Size([16, 2]) ← バッチ16枚 × 2クラス分のスコア

全結合層の入力サイズ「32×16×16」は暗記ではなく計算で出せます。kernel_size=3, padding=1 の畳み込みは画像サイズを変えず、MaxPool2d(2) が半分にする——なので 64→32→16 と縮小され、最終チャネル数32を掛けて 32×16×16=8192 です。この「サイズを追う力」はE資格のコーディング問題でも中心的に問われるスキルです。

💡 出力層に活性化関数を付けないのはなぜ?

モデルの出力は「ロジット」(softmax前の生スコア)のままにします。次に使う CrossEntropyLoss が内部でsoftmax相当の処理を行うため、モデル側でsoftmaxを掛けると二重適用になり学習がうまく進みません。地味ですが実務でも頻発するミスです。

学習ループを回す

PyTorchの学習ループには決まった「型」があります。①勾配をリセット → ②順伝播 → ③損失計算 → ④逆伝播 → ⑤パラメータ更新。この5拍子を体で覚えてしまいましょう。

Python
criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) for epoch in range(5): model.train() total_loss = 0.0 for xb, yb in train_loader: optimizer.zero_grad() # ① 前回の勾配をリセット out = model(xb) # ② 順伝播 loss = criterion(out, yb) # ③ 損失計算 loss.backward() # ④ 逆伝播(勾配計算) optimizer.step() # ⑤ パラメータ更新 total_loss += loss.item() print(f"epoch {epoch+1}: 平均loss = {total_loss / len(train_loader):.4f}") # epoch 1: 平均loss = 0.7294 # epoch 2: 平均loss = 0.6368 # epoch 3: 平均loss = 0.5458 # epoch 4: 平均loss = 0.4150 # epoch 5: 平均loss = 0.2783

エポックを重ねるごとに損失が下がっていれば、モデルは学習を進められています。ただし、ここで見ているのは学習データの損失だけである点に注意してください。学習データの損失だけが下がり続け、検証データでの性能が悪化していく場合は過学習(オーバーフィッティング)のサインです。実務では学習・検証の両方の損失を監視します。逆に、損失がまったく下がらないときにまず疑うのは「zero_grad() の書き忘れ(勾配が溜まり続ける)」や「step() の書き忘れ(パラメータが一切更新されない)」。エラーが出ないぶん気づきにくい、代表的な落とし穴です。

また、損失の集計に loss.item() を使っている点にも注目してください。Tensorのまま足し込むと計算グラフを抱え込んでメモリを圧迫します。数値の記録には .item()、が鉄則です。

評価する ── 医療AIでは「再現率」に注目

学習が終わったら、モデルが見たことのない検証データで性能を確かめます。評価時のお作法は2つ。model.eval()(DropoutやBatchNormを評価モードに切り替える)と、torch.no_grad()(勾配計算を止めて高速・省メモリにする)を必ずセットで使います。今回のモデルにはDropoutもBatchNormも入っていないため eval() で挙動は変わりませんが、実際のCNNではどちらも多用されるので、評価の前には常にこの2つを書く癖をつけておきましょう。

Python
model.eval() tp = fp = fn = tn = 0 correct = 0 with torch.no_grad(): for xb, yb in val_loader: pred = model(xb).argmax(dim=1) # スコア最大のクラスを予測とする correct += (pred == yb).sum().item() tp += ((pred == 1) & (yb == 1)).sum().item() fp += ((pred == 1) & (yb == 0)).sum().item() fn += ((pred == 0) & (yb == 1)).sum().item() tn += ((pred == 0) & (yb == 0)).sum().item() accuracy = correct / len(val_set) # 分母が0になるケース(該当クラスが1枚もない等)に備えてガードを入れる recall = tp / (tp + fn) if tp + fn > 0 else 0 # 再現率: 実際の「あり」をどれだけ拾えたか precision = tp / (tp + fp) if tp + fp > 0 else 0 # 適合率: 「あり」と予測した中の正解率 print(f"正解率: {accuracy:.3f} 再現率: {recall:.3f} 適合率: {precision:.3f}") # 正解率: 0.950 再現率: 0.913 適合率: 1.000

正解率95%——なかなか良さそうに見えます。しかし医療AIでは、正解率だけを見てはいけません。今回の結果を分解すると、「実際に結節がある23枚のうち2枚を『なし』と判定」しています(再現率91.3%)。検診のような「見逃しが致命的」な場面では、この再現率(Recall)こそが最重要指標になります。

💡 再現率と適合率の使い分け

見逃しを減らしたい(がん検診など)→ 再現率を重視。誤検知による余計な精査を減らしたい → 適合率を重視。どちらを優先するかは「間違いのコスト」で決まります。この考え方はG検定・E資格の両方で問われる重要ポイントです。なお本記事では再現率を例にしましたが、実務では混同行列(TP/FP/FN/TNの内訳表)を眺めつつ、F1スコアやROC-AUCなどの指標も合わせて総合的に評価します。

モデルの保存と読み込み

学習したモデルは state_dict()(重みの辞書)として保存するのが定石です。モデル全体を丸ごと保存するより安全で、環境をまたいだ移植もしやすくなります。

Python
# 保存(重みだけを辞書として書き出す) torch.save(model.state_dict(), "ct_cnn.pth") # 読み込み(同じ構造のモデルを作ってから重みを流し込む) model2 = SimpleCNN(num_classes=2) model2.load_state_dict(torch.load("ct_cnn.pth")) model2.eval() # 推論に使う前に評価モードへ

これで「学習 → 評価 → 保存 → 再利用」まで、医療画像AI開発の一連のサイクルが手元で回せるようになりました。

まとめ ── 次のステップへ

シリーズ3記事を通して、DICOMの読み込みから分類モデルの学習・評価までを一気通貫で実装しました。今回のポイントを振り返ります。

ここから先は、転移学習(学習済みモデルの再利用)やデータ拡張で精度を高めたり、セグメンテーション(病変の輪郭抽出)に進んだりと、世界が一気に広がります。その土台になるのが、今回使ったPyTorchの「型」を正確に読み書きできる力です。

PyTorchの「型」を、試験レベルまで固めよう

本記事で使った学習ループ・shape計算・評価指標は、E資格コーディング試験の核心でもあります。無料コースで読解力を鍛え、実践模試で本番形式の総仕上げを。

無料のE資格コーディング対策コース E資格 実践模試 コーディング特化版を見る

医療AIナビ 運営者

「医療×AI」を専門とする現役AIエンジニア。非専門家からAI開発に参入した経験をもとに、医療AIの最新情報やAI資格対策を発信しています。E資格・G検定・Generative AI Test合格済み。