# Source listing metadata (from the file viewer): 23 lines, 543 B, Python.
from sklearn import tree
# Training samples: one row per person, laid out as [height, flag].
#   height -- numeric; presumably centimetres given values 155-180 (TODO confirm).
#   flag   -- binary attribute: 1 = present, 0 = absent.
# NOTE(review): the original comment called this second column 体重 ("weight"),
# which does not fit a 0/1 encoding; the attribute it really encodes is unknown.
features = [
    [178, 1],
    [155, 0],
    [180, 1],
]

# Target class for each row of ``features``, in the same order.
# (The original comment named these 特征值, "feature values", but they are labels.)
labels = [
    'male',
    'female',
    'male',
]
def decision_tree_classifier(samples=None, *, random_state=0):
    """Fit a decision tree on the module-level toy data and print predictions.

    Trains ``sklearn.tree.DecisionTreeClassifier`` on the module-level
    ``features`` / ``labels`` lists. It then prints one line per query sample,
    in the form ``Data[158, 0] is label for:  ['female']``.

    Args:
        samples: Iterable of ``[height, flag]`` rows to classify. Defaults to
            ``[[158, 0], [190, 1]]``, the two rows this script always queried.
        random_state: Seed forwarded to the classifier (keyword-only).
            sklearn randomly permutes the features at every split. On this
            training data both features separate the classes equally well, so
            an unseeded tree may choose a different split on each run. The
            default ``0`` makes fitting reproducible; pass ``None`` to restore
            the original unseeded behaviour.

    Returns:
        The array returned by ``clf.predict``: one predicted label per sample,
        in input order.
    """
    if samples is None:
        samples = [[158, 0], [190, 1]]
    # Materialise once so a generator can be both predicted and printed.
    rows = [list(row) for row in samples]

    # Create the classifier and train it (``fit`` returns the estimator itself,
    # so no reassignment is needed).
    clf = tree.DecisionTreeClassifier(random_state=random_state)
    clf.fit(features, labels)

    # Predict every row in a single call instead of one ``predict`` per sample.
    predictions = clf.predict(rows)
    for i, row in enumerate(rows):
        shown = ', '.join(map(str, row))
        # Slice rather than index so each line prints a one-element array,
        # matching the original output format (e.g. ['female']).
        print(f'Data[{shown}] is label for: ', predictions[i:i + 1])
    return predictions
if __name__ == '__main__':
    # Script entry point: train the toy classifier and print its predictions.
    # This is skipped when the module is imported rather than executed.
    decision_tree_classifier()