今回は機械学習における重要概念の1つである「過学習(Overfitting)」について学習した。
これまでTitanicデータセットを用いてロジスティック回帰モデルを作成し、AccuracyやPrecision、Recallなどの評価指標を確認してきた。
今回はその続きとして、「モデルの性能はどのように評価するべきか」「学習データに強すぎるモデルは本当に優秀なのか」を考える回となった。
これまで作成したモデルの確認
まず現在のロジスティック回帰モデルについて、学習データとテストデータのAccuracyを比較した。
print("train")
print(model.score(X_train, y_train))
print("test")
print(model.score(X_test, y_test))
実行結果は以下の通り。
train 0.800561797752809 test 0.8044692737430168
model.score()とは何か
最初に疑問に思ったのは、model.score()が何を計算しているのかという点である。
確認してみると、分類モデルの場合はAccuracy(正解率)を返していることが分かった。
つまり以下のコードとほぼ同じ結果になる。
from sklearn.metrics import accuracy_score accuracy_score(y_train, model.predict(X_train)) accuracy_score(y_test, model.predict(X_test))
学習済みモデルに対して予測を行い、その結果と正解ラベルを比較して正解率を算出している。
学習データとテストデータの差を確認する
今回の結果では、trainとtestの差は約0.004であった。
train : 0.8006 test : 0.8045 差分 : 約0.004
この結果について自分なりに考察した。
- 学習データとテストデータでほぼ同じ精度が出ている
- 過学習の兆候は見られない
- Kaggleなどで0.01〜0.05の改善を狙う世界では無視できない差ではないが、異常な差とも言えない
ChatGPTとの壁打ちの中でも、この解釈で問題ないことを確認できた。
過学習とは何か
ここから本題である。
講師役のChatGPTから次のような問いを投げられた。
もし学習データを完全に暗記できるモデルがあったらどうなるか?
自分の回答は以下である。
- X_trainに対する予測を完全にy_trainへ一致させられる
- train accuracyは100%になる
- しかし未知データの特徴を学習しているわけではない
- そのためtest accuracyは低くなる可能性が高い
これは機械学習における過学習の本質そのものだった。
暗記と学習の違い
ここで印象に残った説明があった。
機械学習モデルが学ぶべきなのは、データそのものではなく「データの背後にある法則」である。
Titanicであれば例えば以下のような傾向である。
- 女性は生存率が高い
- 上級客室の乗客は生存率が高い
- 子供は比較的助かりやすい
これらは新しい乗客に対しても適用できる。
一方で、
- 乗客Aは生存
- 乗客Bは死亡
- 乗客Cは生存
といった個別データを記憶しても、新しい乗客には役に立たない。
これが「暗記」と「学習」の違いである。
なぜ過学習が問題なのか
初心者の頃は「Accuracyが高いほど良いモデル」と考えがちである。
しかし実際にはそう単純ではない。
例えば以下のような結果だったらどうだろうか。
Train Accuracy : 100% Test Accuracy : 68%
一見すると学習データでは完璧である。
しかし未知データに対しては大きく性能が落ちている。
これは実務では危険信号である。
本番環境で使うのは未知データだからだ。
今回の学び
今回の学習で理解できたことは以下である。
- model.score()は分類問題ではAccuracyを返す
- train/test両方を確認することが重要
- 高いAccuracyだけでは良いモデルとは言えない
- 未知データへの一般化性能が重要
- 過学習とは学習データを覚えすぎた状態である
これまでの学習では「モデルを作る」ことに意識が向いていた。
しかし今回の学習で、「作ったモデルをどう評価するか」が同じくらい重要だと実感した。
次回予告
次回はDecisionTreeClassifierを使って、実際に過学習を発生させる予定である。
ロジスティック回帰は比較的シンプルなモデルなので過学習しにくい。
一方で決定木は条件分岐を増やすことで学習データをかなり細かく記憶できる。
実際にtrain accuracyとtest accuracyがどのように変化するのかを確認しながら、過学習を数字で体感していきたい。