소소하지만 소소하지 않은 개발 공부/머신 러닝 교과서

10.4 RANSAC을 사용하여 안정된 회귀 모델 훈련, python

still..epochs 2022. 12. 23. 14:42

* 본 포스팅은 머신러닝교과서를 참조하여 작성되었습니다.

* https://github.com/rickiepark/python-machine-learning-book-3rd-edition

 

GitHub - rickiepark/python-machine-learning-book-3rd-edition: <머신 러닝 교과서 3판>의 코드 저장소

<머신 러닝 교과서 3판>의 코드 저장소. Contribute to rickiepark/python-machine-learning-book-3rd-edition development by creating an account on GitHub.

github.com

 

선형 회귀 모델은 이상치(outlier)에 크게 영향을 받을 수 있다. 이상치를 제거하려면 항상 해당 분야의 지식만 아니라 데이터 과학자로서 식견도 필요하다.

 

그렇다면 이상치를 제거하는 방식 대신 RANSAC(RANdom SAmple Consensus) 알고리즘을 사용하는 안정된 회귀 모델에 대해 알아보자.

 

반복적인 RANSAC 알고리즘을 다음과 같이 정리할 수 있다.

  1. 랜덤하게 일부 샘플을 정상치로 선택하여 모델을 훈련한다.
  2. 훈련된 모델에서 다른 모든 포인트를 테스트한다. 사용자가 입력한 허용 오차 안에 속한 포인트를 정상치에 추가한다
  3. 모든 정상치를 사용하여 모델을 다시 훈련한다.
  4. 훈련된 모델과 정상치 간의 오차를 추정한다.
  5. 성능이 사용자가 지정한 임계 값에 도달하거나 지정된 반복 횟수에 도달하면 알고리즘을 종료한다. 그렇지 않으면 단계 1로 돌아간다.
from sklearn.linear_model import RANSACRegressor
ransac = RANSACRegressor(LinearRegression(), max_trials=100, min_samples=50, loss='absolute_loss',
                        residual_threshold=5.0, random_state=0)
ransac.fit(X, y)
  • RANSACRegressor 의 최대 반복 횟수 100
  • min_samples=50 최소 샘플 개수 50개
  • loss = 'absolute_loss', 학습한 직선과 샘플 포인트 간 수직 거리의 절댓값
  • residual_threshold = 5.0, 학습한 직선과 수직 거리가 5 이내에 있는 정상 샘플만 포함
  •  
inlier_mask = ransac.inlier_mask_
outlier_mask = np.logical_not(inlier_mask)
line_X = np.arange(3, 10, 1)
line_y_ransac = ransac.predict(line_X[:, np.newaxis])
plt.scatter(X[inlier_mask], y[inlier_mask], c='steelblue', edgecolor='white', marker='o', label='Inliers')
plt.scatter(X[outlier_mask], y[outlier_mask], c='limegreen', edgecolor='white', marker='s', label='Outliers')
plt.plot(line_X, line_y_ransac, color='black', lw=2)
plt.xlabel('Average number of rooms [RM]')
plt.ylabel('Price in $1000s [MEDV]')
plt.legend(loc='upper left')
plt.show()

RANSAC으로 학습한 선형 모델

산점도에서 볼 수 있듯이 동그라미로 표시된 정상치에 선형 회귀 모델이 훈련되었다.

 

다음 코드로 이 모델의 기울기와 절편을 출력하면 이전 절에서 RANSAC을 사용하지 않고 구한 직선과 조금 다른 것을 알 수 있다.

print('기울기: %.3f' % ransac.estimator_.coef_[0])
print('절편: %.3f' % ransac.estimator_.intercept_)


>> 기울기: 10.735
   절편: -44.089

 

RANSAC을 사용하면 데이터셋에 있는 이상치의 잠재적인 영향을 감소시킨다. 하지만 이 방법이 본 적 없는 데이터에 대한 예측 성능에 긍정적인 영향을 미치는지 미치지 못하는지 알 지 못한다.