이번에 살펴볼 개념은 k-fold Cross-validation입니다. 앞서 다뤘던 Validation Set ApproachLeave-One-Out Cross-Validation에 이어 마지막 validation 방법입니다. 사실 개념은 LOOCV와 크게 다르지 않으니 쉽게 이해하실 수 있을 것입니다. 먼저 아래의 그림을 보시죠.

img

이렇게 LOOCV는 각 하나의 샘플씩 $N$번을 반복했다면, k-fold CV는 데이터셋을 k개의 같은 크기로 나눈 다음에 하나의 부분씩을 test set으로 사용하여 k개의 test performance를 평균내는 것을 의미합니다. 이를 수식으로 나타내면, \[CV_{(k)}=\frac{1}{k}\sum_{i=1}^k \text{MSE}_i\] 어떻게 보면 양 극단인 validation set approach와 LOOCV의 중간 정도라고 보시면 될 것 같습니다. 계산적으로도 LOOCV에 비해 훨씬 빠르면서, 그래도 k개의 test set에 대해서 평균을 내기 때문에 validation set approach 보다는 더 안정적인 test error를 가져올 수 있죠. 그래서 많이들 5-fold나 10-fold를 사용합니다. 앞서 사용했던 Auto dataset을 사용해 보겠습니다.

library(ISLR)
library(boot) #cv.glm을 위한 라이브러리

앞서 말씀드렸던 cv.glm의 함수 내에 K라는 인자를 건들지 않으면 자동적으로 LOOCV를 계산하게 되는데요, 만약 우리가 이 K값을 바꿔준다면 k-fold CV를 계산해 줍니다. 먼저 5-fold CV를 1~5차함수로 fitting하고 그 결과를 구해 보겠습니다.

degree=1:5
cv.error5=rep(0,5)
for(d in degree){
  glm.fit=glm(mpg~poly(horsepower,d), data=Auto)
  cv.error5[d]=cv.glm(Auto,glm.fit,K=5)$delta[1] # K값만 5로 명시해 주면 5-fold CV가 계산됩니다.
}

훨씬 계산이 빠름을 확인하실 수 있죠? 그럼 이번엔 10-fold CV를 계산해 보겠습니다.

cv.error10=rep(0,5)
for(d in degree){
  glm.fit=glm(mpg~poly(horsepower,d), data=Auto)
  cv.error10[d]=cv.glm(Auto,glm.fit,K=10)$delta[1]
}

마지막으로 비교를 위해서 앞선 Leave-One-Out Cross-Validation에서 구했던 결과를 같이 plot해 보겠습니다.

loocv=function(fit){
  h=lm.influence(fit)$h
  mean((residuals(fit)/(1-h))^2)
}

cv.error=rep(0,5)
for(d in degree){
  glm.fit=glm(mpg~poly(horsepower,d), data=Auto)
  cv.error[d]=loocv(glm.fit)
}

#plot
plot(degree,cv.error5,type="b",col="blue", ylab="CV Error")
lines(degree,cv.error10,type="b",col="red")
lines(degree,cv.error,type="b")

#legend를 통하여 범례를 만들어 줍니다.
legend("topright", c("5-fold CV", "10-fold CV", "LOOCV"), pch=1, col=c('blue', 'red', 'black'))

center

여러번 반복해 보시면 아시겠지만, LOOCV는 randomness가 존재하지 않기 때문에 할 때마다 동일한 결과를 얻습니다.(Stable) 반면, k-fold CV는 k값이 작을수록 할 때마다 변동이 크게 되어 있습니다. 그럼에도 불구하고 k-fold를 많이 쓰는 것은 단순히 계산적인 속도의 이득뿐만이 아니라, LOOCV의 경우는 training data간의 상관관계가 매우 높을 수밖에 없기 때문에 이에 의해 발생될 수 있는 overfitting의 가능성도 줄이고자 하는 목적도 있기 때문입니다.

모든 model fitting에 필수적인 validation 방법이니 반드시 숙지하시고 잘 사용하시기 바랍니다 :)