NLP

Kfold( Cross Validation 교차검증 )

still..epochs 2023. 8. 3. 22:45

일반적으로 데이터를 모델에 사용할 수 있도록 정제한 후에는 train과 test 셋으로 분할하여 모델을 학습 시킨 후, test 셋으로 모델의 성능평가를 진행한다. 하지만 데이터가 적거나 동일한 데이이터 내에서 계속해서 학습을 진행하는 경우, train 셋을 집중적으로 학습하여 train 셋의 데이터는 모델이 잘 예측하지만 새롭게 들어오는 데이터에 대한 예측은 부정확한 경우가 발생한다. 이러한 현상을 오버피팅(overfitting)이라고 부른다. 

 

이러한 문제점을 해결하기 위해, 여러가지 방법론이 존재하지만 오늘은 kfold에 대해 설명해보려고 한다.

출처 : https://scikit-learn.org/stable/modules/cross_validation.html

일반적으로 학습에 사용할 데이터는 Training data, Test data 로 분류하여 사용한다. 그런데 Training data를 Fold로 나누어 위 그림의 상황처럼 1:4 비율로 돌아가며 학습시킨다면 Training data를 학습할 때, 데이터를 적극적으로 fitting에 활용할 수 있다. 따라서 kfold 를 통해 모델의 완성도를 높일 수 있다.

 

코드

iris 데이터 셋을 활용하여 kfold를 사용해보자.

from sklearn.datasets import load_iris
import pandas as pd
import numpy as np

from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# iris 데이터
iris = load_iris()
df = pd.DataFrame(iris.data, columns = iris.feature_names)
df['label'] = iris.target

df.head(5)

 

X = iris.data
y = iris.target
train_X, test_X, train_Y, test_Y = train_test_split(X, y, test_size=0.3)

lr = LogisticRegression()
lr.fit(train_X, train_Y)
r2_score = lr.score(test_X, test_Y)

print(r2_score)

>> 0.9333333333333333

 

단순히 kfold 없이 모델을 학습시켰을때는 약 0.93의 정확도를 얻을 수 있다.

그렇다면 kfold를 활용해보자

 

kfold =KFold(n_splits=5, shuffle=True, random_state=123)
scores = []

for k, (train, test) in enumerate(kfold.split(train_X, train_Y)):
    lr.fit(train_X[train], train_Y[train])
    score = lr.score(train_X[test], train_Y[test])
    scores.append(score)

    print('Fold: {}, 정확도: {:.3f}'.format(k+1, score))
    
>>  Fold: 1, 정확도: 0.952
	Fold: 2, 정확도: 0.952
	Fold: 3, 정확도: 1.000
	Fold: 4, 정확도: 0.952
	Fold: 5, 정확도: 1.000

 

이제 평균 정확도를 살펴보면

 

print('Kfold 정확도 점수: %.3f +/- %.3f' %(np.mean(scores), np.std(scores)))


>> Kfold 정확도 점수: 0.971 +/- 0.023

로 상승한 것을 살펴볼 수 있다.

'NLP' 카테고리의 다른 글

Confusion Matrix, Accuracy, Precision, Recall, F1 score  (0) 2023.08.02
Model 평가 및 지표들  (0) 2023.07.31
reset_index() 사용 방법  (0) 2023.07.27