AIを用いた機械学習を学ぶ上で、必要な知識の一つに「過学習」があります。
機械学習において過学習が起こると、データ処理がうまく行われず、正確な結果を得ることができず、実運用が難しくなります。
過学習を理解し、対処方法を知っておくことは機械学習を学ぶ上で非常に重要です。
本記事では下記のポイントについてわかりやすく解説していきます。
- 過学習とは?
- 過学習が発生する主な原因
- 過学習の判断基準
- 過学習を防ぐ方法
過学習とは?
機械学習において陥りやすいトラブルの一つに過学習があります。
過学習について理解を深める前にまずは、機械学習について以下のポイントを抑えておきましょう。
- 機械学習とはデータを分析する手法の1つ
- データからコンピュータが自動で学習し、ルールやパターンを発見する方法
- 発見したパターンを未知のデータに当てはめてAI予測モデルを構築
この機械学習に使うデータをコンピュータが学習しすぎた結果、訓練データと過剰に適合してしまい、未知のデータ(新しいデータ)に対する精度が低下している状態のことを「過学習」と言います。
過学習が起こると、訓練データ上では正解率が高いにも関わらず、テストデータでは正解率が低いという状態に陥ってしまいます。
AIなどの機械学習では過学習を起こさず、あらゆるインプットに対して正しいアウトプットを行うことができるかが重要です。
過学習の具体例
過学習の具体例として「特定のクラスのテストの点数分布から学年全体の点数分布を予測するモデル」を取り上げます。
国語と算数のテストの結果を縦軸と横軸とし、実際の点数を記入して散布図を作成。コンピュータの学習予測パターンを曲線として書き込みます。
この場合、適正な予測ができているモデルは全体の分布をバランスよくとらえた緩やかな曲線になります。
しかし、過学習が起きている場合はデータが過剰に適合してしまうため、細かく湾曲する曲線ができてしまうのです。
下図の3つのグラフを見ましょう。左から「学習不足」「適切」「過学習」の順に並んでいます。
適切なモデルと比較すると、学習不足の場合は訓練データの特徴を捉えているだけでモデルの複雑さがありません。
過学習においては、訓練データに過剰適合してしまい細かく湾曲していることが分かります。
引用:How to avoid Over-fitting using Regularization?
過学習の問題点
過学習が生じると、訓練データに対しては高い精度を示すことができます。
一方で、新しいデータ(未知のデータ)に対しては低い精度で予測を示す可能性が高くなります。
モデルが訓練データのノイズやランダムな変動まで学習してしまい、真のトレンドやパターンを学習することができず、特定のデータに固有の特徴を学習してしまった結果となります。
つまり、データ一つひとつが持つ偏りに予測モデルが適合してしまい、データ全体としての傾向があいまいな状態です。
ビジネスの場面で過学習したモデルを用いて戦略を立てることのリスクとして、目的を果たせないどころか、間違った意思決定をしてしまいます。
結果として、企業のブランドイメージに悪影響を及ぼす可能性もあるため注意が必要です。
そのため、検証の段階で過学習が生じていないかを評価し、対策をすることが重要となります。
過学習が発生する主な原因
では、過学習が発生する主な原因にはどのようなことが挙げられるのでしょうか。
主な原因としては次の3つが挙げられます。
- データが不足しているから
- データが偏っているから
- モデルが複雑だから
データが不足しているから
過学習の原因の一つに訓練データが不足しているケースが挙げられます。
アルゴリズムは優れたものができていても、データが不足しているとAIは正しい学習を行うことができません。
データが不足しているために、モデルはデータ内の小さな変動やノイズに敏感に反応してしまうことが原因です。
AIにとって学習で得られるデータがすべてであるため、十分なデータ量を確保しAIに提供する必要があります。
特に複雑なモデルを使用する場合はデータ不足による過学習が起きやすくなるため注意が必要です。
データが偏っているから
機械学習は与えられたデータからしか分析を行うことはできません。
そのため、十分な量の偏りのないデータを学習させる必要があります。
偏ったデータばかりを学習させてしまうと、客観性に欠けた偏った分析や偏った予測しかできません。
精度の高い予測結果を得るためには、様々な特徴のある大量のデータを用意し学習させることが重要です。
モデルが複雑だから
さまざまなアルゴリズムを利用してモデルを構築すると、高度な分析が行えるようになるメリットがある一方で、過学習のリスクが増大する可能性があります。
複雑なモデルの場合、多くのパラメータを持ち、訓練データの細かい特徴まで捉えることができますが、本質的なパターンだけではなくノイズにまで適合してしまいます。
場合によっては、モデルの複雑さを調整することが必要です。
過学習の判断基準は?
では、過学習に陥っているかはどのように判断したらよいのでしょうか。
ここからは、過学習の判断基準について以下の4つの点から見てみましょう。
- データの種類を区別する
- ホールドアウト法で検証する
- 交差検証法で検証する
- 学習曲線を確認する
データの種類を区別する
モデルを構築する前に用意したデータを以下の3つのデータに分けておきます。
- 訓練データ
モデルの構築に使用 - 検証データ
モデルの精度を検証し改善するために使用 - テストデータ
予測モデルを最終的にテストする際に使用
これらの各データは以下の順序で使用します。
- 訓練データを使って予測モデルを作成
- 検証データを用いてうまくいかない場合の原因究明及び改善の実施によりモデルの精度を高める
- 訓練データと検証データによって完成した予測モデルに対し、テストデータを使用してテストを実施
このように、検証データとテストデータを行い2段階でチェックを実施することにより過学習に陥っていることに気づきやすくなります。
ホールドアウト法で検証する
ホールドアウト法は機械学習におけるデータテスト手法の一つです。
ホールドアウト法とは、訓練データと検証データ、テストデータの3つに分けて、モデルを作成するごとに検証を実施し分析をする手法のことです。
この手法のメリットとしては、実装が簡単でコンピュータへの負担をかけることなく実施できる点にあります。
基本的には、2つのデータの平均値や中央値などを見て、両データが同じ傾向を持っているかを確認することで過学習かを見分けることができます。
交差検証法で検証する
交差検証法はホールドアウト法と同様に、機械学習におけるデータテスト手法の一つです。
交差検証法とは、ホールドアウト法と同じように訓練データと検証データ、テストデータの3つに分けて実施する手法です。
ホールドアウト法では1通りの分割データを使ってテストを行いますが、交差検証法では複数の分割方法を試して全体の平均をとるという違いがあります。
複数の分割方法を試すことで、データの傾向の違いから生じるか学習を最小化することができます。
交差検証法はホールドアウト法よりも信頼性の高い結果を得られる手法ですが、元データの量が多く、コンピュータへの負荷もかかるため注意が必要です。
学習曲線を確認する
学習曲線は、訓練データの精度と検証データの精度を曲線で表し、過学習が起きていないか判断をする方法の一つです。
学習曲線をみることで判断できることは以下の2つになります。
- サンプル数の不足を確認できる
- モデルの過学習がおきていないかを確認できる
訓練データと検証データの曲線のギャップが大きい場合は過学習が起きている可能性が高くなります。
2つの曲線のギャップに注目し、このギャップが大きければ大きいほど予測モデルとしては使えません。
過学習を防ぐ方法は?
過学習を防ぐ方法はいくつかあります。
ここでは、以下の3つの防止策について見ていきましょう。
- データ量を増やす
- 正規化を行う
- アンサンブル学習を実施する
データ量を増やす
学習データの数は過学習に大きく関わる重要な部分です。
学習データの数が多ければ多いほど、データのバリエーションが増えるため未知のデータ(新しいデータ)に対応できるようになります。
しかし、AIモデルの構築に必要な十分な質と量のデータを用意するにはコストがかかる場合があります。
場合によっては、既存の学習データを拡張し複数の学習データのパターンを作成するという方法をとる必要があるかもしれません。
正規化を行う
正規化とは、複雑化したモデルを単純なモデルへと戻していくという手法で過学習を防ぐ対策の一つです。
正規化には主に「L1正規化」と「L2正規化」の2つの手法があります。
L1正規化
L1正規化は、特定の特徴の係数をゼロにすることで、モデルから不必要な特徴を取り除く方法です。この方法により、モデルの複雑さが減少し、過学習を防げます。特に、多数の特徴が存在する場合に効果的です。
L2正規化
L2正規化は、モデルの重みの大きさを制限することで、複雑な要素を抑えて過学習を防ぎます。特にデータの数が少ない場合に有効で、L1正規化を用いたモデルと比較して、予測精度が高くなる傾向にあります。
アンサンブル学習を実施する
アンサンブル学習とは、重回帰分析や決定木分析などの複数の予測モデルを組み合わせ、最終予測結果を出す手法です。
アンサンブル学習に使われる主な手法は「バギング」、「ブースティング」、「スタッキング」の3つがあります。
複数のモデルを使用することによって効果的に過学習を防ぐことができます。
バギング
バギングは、互いに独立した複数の学習器を並列に作成し予測を平均化して、1つの優れた予測を作る手法です。モデルの分散が減少されるため、過学習が抑制されます。
バギングの詳細は下記の記事で解説しているので、ぜひご参照ください。
ブースティング
ブースティングは、複数の学習器を使用して順番に学習して1つの予測モデルを作成する手法です。予測のバイアスを減少できるため、よりより精度が高まります。
スタッキング
スタッキングは、第一段階で様々なアルゴリズムを学習させ予測値を出力し、第二段階で予測値を特徴量として学習する手法です。この二段階のアプローチにより、異なるモデルの予測の強みを組み合わせることができます。
まとめ:過学習は事前に防止できる
本記事では、過学習の概要から、過学習が起こる原因、対策について紹介をしました。
AIモデルの構築において、訓練データの質と量は非常に重要であることがおわかりいただけたかと思います。
過学習に陥ると正確な予測結果を得られることができず、実運用することが難しくなります。
このような事態に陥らないためにも、本記事でご紹介した過学習に陥る原因を理解し、対策をするスキルが必要です。
適切な手法で検証およびテストを行うことで過学習は事前に防ぐことができます。
AIの活用に関するご相談は、是非一度Jiteraへご相談ください。
Jiteraでは要件定義を書くだけでAIがアプリ・システムを開発できるプラットフォームがあります。
お客様の環境に合わせて最適な活用方法をご提案いたします。