데이터분석/데이터분석

분류모델 : 의사결정나무(DecisionTree)

이규승 2022. 5. 11. 21:45
728x90

classification, regression 모두 가능하나 분류모델로 더 많이 사용됨
Decision Tree는 여러 가지 규칙을 순차적으로 적용하면서 독립 변수 공간을 분할하는 분류 모형이다.

의사결정나무 참고 사이트 : https://ratsgo.github.io/machine%20learning/2017/03/26/tree/

 

의사결정나무(Decision Tree) · ratsgo's blog

이번 포스팅에선 한번에 하나씩의 설명변수를 사용하여 예측 가능한 규칙들의 집합을 생성하는 알고리즘인 의사결정나무(Decision Tree)에 대해 다뤄보도록 하겠습니다. 이번 글은 고려대 강필성

ratsgo.github.io

코드 참고 사이트 : https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

 

sklearn.tree.DecisionTreeClassifier

Examples using sklearn.tree.DecisionTreeClassifier: Classifier comparison Classifier comparison, Plot the decision surface of decision trees trained on the iris dataset Plot the decision surface of...

scikit-learn.org

import collections
from sklearn import tree

# setting
x = [[180,15],[177,42],[156,35],[174,5],[166,33],[170,12],[171,7]]
y = ['man', 'woman', 'woman', 'man', 'woman', 'man', 'woman']
label_names = ['height', 'hair length']

# 의사결정나무
model = tree.DecisionTreeClassifier(criterion = 'entropy', random_state = 0)
# criterion은 gini, entropy 2가지가 있다.
model.fit(x,y)
pred = model.predict(x)
mydata = [[171,18]] # new data 입력
new_pred = model.predict(mydata) # new predict
print('분류 예측 결과:',new_pred)

의사결정나무 시각화 참고 사이트 : https://cafe.daum.net/flowlife/SBU0/13

 

Daum 카페

 

cafe.daum.net

import pydotplus

dot_data = tree.export_graphviz(model, feature_names = label_names,
                                out_file =None, filled = True, rounded = True)
graph = pydotplus.graph_from_dot_data(dot_data)
colors = ('red', 'orange')
edges = collections.defaultdict(list)
print(edges, type(edges)) #defaultdict(<class 'list'>, {}) <class 'collections.defaultdict'>

for e in graph.get_edge_list():
    edges[e.get_source()].append(int(e.get_destination()))
    
print(edges)

for e in edges:
    edges[e].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[e][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

import matplotlib.pyplot as plt
from matplotlib.pyplot import imread #imageread

img = imread('tree.png')
plt.imshow(img)
plt.show()

collections.defaultdict 내용

https://www.daleseo.com/python-collections-defaultdict/

 

[파이썬] 사전의 기본값 처리 (dict.setdefault / collections.defaultdict)

Engineering Blog by Dale Seo

www.daleseo.com

 

728x90