데이터분석/데이터분석

sklearn을 이용한 과적합 방지

이규승 2022. 5. 11. 21:56
728x90

1. test / train

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

iris = load_iris() # iris 이용하기
train_data = iris.data
train_label = iris.target

# 분류모델
dt_clf = DecisionTreeClassifier() # sklearn의 다른 분류모델을 써도 됌
dt_clf.fit(train_data, train_label)
pred = dt_clf.predict(train_data)
print('분류 정확도:',accuracy_score(train_label, pred)) #(실제값, 예측값)

# 정확도 100%는 과적합 문제를 의심해라

# 과적합 방지 1 (test/ train)
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target,
                                                     test_size=0.3, random_state=121)
dt_clf.fit(x_train, y_train) # 모델학습은 train data
pred2 = dt_clf.predict(x_test) # 모델평가는 test data
print('분류 정확도:',accuracy_score(y_test, pred2)) #0.9555

2. kfold

# 과적합 방지 2 (k-fold 이용)
# 모델 학습 시 데이터의 편중을 방지하고자 학습 데이터를 쪼개 학습과 평가를 병행
from sklearn.model_selection import KFold
import numpy as np
features = iris.data
label = iris.target
dt_clf = DecisionTreeClassifier(criterion = 'entropy', random_state = 0)
kfold = KFold(n_splits = 5) # 5번 나눈다.
cv_acc = []
# 전체 행 수가 150, 학습데이터 : 4/5(120개), 검증데이터 : 1/5(30개)로 분할해 가며 학습을 진행

n_iter = 0
for train_index, test_index in kfold.split(features):
    xtrain, xtest = features[train_index], features[test_index]
    ytrain, ytest = label[train_index], label[test_index]
    
    dt_clf.fit(xtrain, ytrain)
    pred = dt_clf.predict(xtest)
    n_iter += 1
    
    # 반복할 때 마다 정확도 측정
    acc = np.round(accuracy_score(ytest,pred),3)
    train_size = xtrain.shape[0]
    test_size = xtest.shape[0]
    print('반복수:{0}, 교차검증 정확도:{1}, 학습데이터 크기:{2}, 학습데이터 크기:{3}'.format(n_iter, acc, train_size, test_size))
    print('반복수:{0}, validation data index:{1}'.format(n_iter,test_index))
    cv_acc.append(acc)
print('평균 검증 정확도:',np.mean(cv_acc))

3. 불균형 데이터인 경우는 KFold 보다는 StratifiedKFold를 사용한다.

for train_index, test_index in skfold.split(features, label): # label 주는거만 다르다
    '''    
    print('n_iter',n_iter)
    print('train_index',len(train_index))
    print('test_index',len(test_index))
    n_iter += 1'''
    xtrain, xtest = features[train_index], features[test_index]
    ytrain, ytest = label[train_index], label[test_index]
    
    dt_clf.fit(xtrain, ytrain)
    pred = dt_clf.predict(xtest)
    n_iter += 1
    
    # 반복할 때 마다 정확도 측정
    acc = np.round(accuracy_score(ytest,pred),3)
    train_size = xtrain.shape[0]
    test_size = xtest.shape[0]
    print('반복수:{0}, 교차검증 정확도:{1}, 학습데이터 크기:{2}, 학습데이터 크기:{3}'.format(n_iter, acc, train_size, test_size))
    print('반복수:{0}, validation data index:{1}'.format(n_iter,test_index))
    cv_acc.append(acc)
print('평균 검증 정확도:',np.mean(cv_acc))

4. crosss_val_score를 이용한 교차검증

from sklearn.model_selection import cross_val_score

data = iris.data
label = iris.target

score = cross_val_score(dt_clf, data, label, scoring = 'accuracy', cv=5) # 5번 시행
print('교차 검증별 정확도:', np.round(score,3))
print('교차 검증별 정확도:', np.round(np.mean(score),3))

5. GridSearchCV : 교차검증과 최적의 속성(하이퍼 파라미터)을 위한 튜닝을 한 번에 처리

from sklearn.model_selection import GridSearchCV

# 여러 개의 속성 값 중 max_depth, min_samples_split에 대하여 최적의 값 찾기
parameters = {'max_depth':[1,2,3], 'min_samples_split':[2,3]}

grid_tree = GridSearchCV(dt_clf, param_grid = parameters, cv = 3, refit = True)
#cv kfold수 refit은 재학습
grid_tree.fit(x_train, y_train)

import pandas as pd
scores_df = pd.DataFrame(grid_tree.cv_results_)
# pd.set_option('max_columns', None)
# print(scores_df)

print('GridSearchCV 최적 파라미터 : ', grid_tree.best_params_)
print('GridSearchCV 최적 정확도 : ', grid_tree.best_score_)
#dt_clf = DecisionTreeClassifier(..., max_depth = 3, min_samples_split=2, ...)

# GridSearchCV가 제공하는 최적의 파라미터로 모델(DecisionTreeClassifier) 생성
estimator = grid_tree.best_estimator_
print(estimator)
pred = estimator.predict(x_test)
print(pred)
print('모델성능(정확도):',accuracy_score(y_test, pred))
728x90

'데이터분석 > 데이터분석' 카테고리의 다른 글

Regressor  (0) 2022.05.11
앙상블 (Esemble Learning)  (0) 2022.05.11
분류모델 : 의사결정나무(DecisionTree)  (0) 2022.05.11
혼동행렬, ROC, AUC  (0) 2022.05.11
로지스틱 회귀분석 (Logistic Regression)  (0) 2022.05.09