본문 바로가기
데이터분석/데이터분석

K-NN (K -Nearest Neighbor)

by 이규승 2022. 5. 16.
728x90
반응형

레이블이 있는 데이터를 사용하여 분류 작업을 하는 알고리즘이다. 데이터로부터 거리가 가까운 k개의
다른 데이터의 레이블을 참조하여 분류한다. 대개의 경우에 유클리디안 거리 계산법을 사용하여 거리를
측정하는데, 벡터의 크기가 커지면 계산이 복잡해진다.

 

KNN 이론 참고 사이트 : https://smecsm.tistory.com/53

 

KNN(K Neighbor Nearest)이란?

KNN ( K-Nearest Neighbor) In [1]: #주피터 노트북 블로그게시용 함수 from IPython.core.display import display, HTML display(HTML(" ")) KNN ( K -Nearest Neighbor) 개요¶ KNN은 K - 최근접 이웃법으로 분..

smecsm.tistory.com

K-NN 모델사용

# KNN : 데이터로부터 거리가 가까운 k개의 다른 데이터의 레이블을 참조하여 분류를 진행

train = [
    [5, 3, 2],
    [1, 3, 5],
    [4, 5, 7]
]

label = [0, 1, 1]

# import matplotlib.pyplot as plt
# plt.xlim([-1, 3])
# plt.ylim([0, 10])
# plt.plot(train, 'o')
# plt.show()

from sklearn.neighbors import KNeighborsClassifier

kmodel = KNeighborsClassifier(n_neighbors = 3, weights = 'distance')
kmodel.fit(train,label)
pred = kmodel.predict(train)
print('pred:',pred)

print('----------------------------')
from sklearn.datasets import load_breast_cancer
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rc('font', family = 'malgun gothic')

cancer = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(cancer.data, cancer.target,
                                                    stratify=cancer.target,
                                                    random_state = 66)
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)#(426, 30) (143, 30) (426,) (143,)

train_acc = []
test_acc = []
neighbors_set = range(1, 11)

for n_nei in neighbors_set:
    clf = KNeighborsClassifier(n_neighbors=n_nei)
    clf.fit(x_train, y_train)
    train_acc.append(clf.score(x_train,y_train)) # train데이터의 정확도
    test_acc.append(clf.score(x_test,y_test))
    
import numpy as np
print(train_acc)
print('train 분류 평균 정확도:', np.mean(train_acc))    
print('test 분류 평균 정확도:', np.mean(test_acc))    
print(test_acc)

plt.plot(neighbors_set, train_acc, label='훈련 정확도')
plt.plot(neighbors_set, test_acc, label='검증 정확도')
plt.xlabel('k값')
plt.ylabel('정확도')
plt.legend()
plt.show() # k 값이 6 인 경우가 가장 합리적인 모델이다.

 

728x90

'데이터분석 > 데이터분석' 카테고리의 다른 글

군집 분석(Clustering)  (0) 2022.05.16
인공신경망 (ANN)  (0) 2022.05.16
나이브 베이즈  (0) 2022.05.16
주성분 분석(PCA)  (0) 2022.05.14
서포터 벡터 머신 (SVM)  (0) 2022.05.14