
Python Machine Learning: Decision Trees 003 - Using Information Entropy to Find the Optimal Split

Date: 2018-11-18 16:37:50


# Using information entropy to find the optimal split
import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:, 2:]   # keep only the two petal features (petal length, petal width)
y = iris.target
y.shape

from sklearn.tree import DecisionTreeClassifier

dt_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy")
dt_clf.fit(X, y)
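As a quick cross-check (this inspection is an addition, not part of the original post), the fitted scikit-learn tree exposes the features and thresholds it actually chose through its tree_ attribute, which can later be compared with the splits found by the manual entropy search below.

# The arrays below have one entry per tree node; a value of -2 marks a leaf.
# Printing them shows which feature index and threshold each internal node uses.
print(dt_clf.tree_.feature)
print(dt_clf.tree_.threshold)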

def plot_decision_boundary(model, axis):
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], int((axis[1] - axis[0]) * 100)).reshape(-1, 1),
        np.linspace(axis[2], axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1, 1),
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]
    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    from matplotlib.colors import ListedColormap
    custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])

    plt.contourf(x0, x1, zz, cmap=custom_cmap)

plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3])
plt.scatter(X[y == 0, 0], X[y == 0, 1])
plt.scatter(X[y == 1, 0], X[y == 1, 1])
plt.scatter(X[y == 2, 0], X[y == 2, 1])
plt.show()

# Simulate the entropy-based splitting procedure by hand
def split(X, y, d, value):
    index_a = (X[:, d] <= value)
    index_b = (X[:, d] > value)
    return X[index_a], X[index_b], y[index_a], y[index_b]

from collections import Counter
from math import log

def entropy(y):
    """Empirical entropy of a label vector, using the natural logarithm."""
    counter = Counter(y)
    res = 0.0
    for num in counter.values():
        p = num / len(y)
        res += -p * log(p)
    return res

def try_split(X, y):
    """Search every feature and every candidate threshold for the split
    that minimizes the summed entropy of the two resulting subsets."""
    best_entropy = float('inf')
    best_d, best_v = -1, -1
    for d in range(X.shape[1]):
        sorted_index = np.argsort(X[:, d])
        for i in range(1, len(X)):
            if X[sorted_index[i - 1], d] != X[sorted_index[i], d]:
                # candidate threshold: midpoint of two consecutive distinct values
                v = (X[sorted_index[i - 1], d] + X[sorted_index[i], d]) / 2
                X_l, X_r, y_l, y_r = split(X, y, d, v)
                e = entropy(y_l) + entropy(y_r)
                if e < best_entropy:
                    best_entropy, best_d, best_v = e, d, v
    return best_entropy, best_d, best_v
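Here entropy implements H(y) = -sum_i p_i * log(p_i) with the natural logarithm, where p_i is the fraction of samples in class i, and try_split keeps the candidate split whose two subsets have the lowest summed entropy. As a small sanity check (added here, not in the original post), the full iris label vector contains three equally frequent classes of 50 samples each, so its entropy should come out to log(3):

# Sanity check (addition): three balanced classes give -3 * (1/3) * log(1/3) = log(3).
print(entropy(y))   # expected to be about 1.0986
print(log(3))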

best_entropy, best_d, best_v = try_split(X, y)
print('best_entropy =', best_entropy)
print('best_d =', best_d)
print('best_v =', best_v)
# best split: dimension 0 (petal length), at value 2.45

best_entropy = 0.6931471805599453
best_d = 0
best_v = 2.45
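A short check on where 2.45 comes from (added, not in the original post): try_split only tests midpoints between consecutive sorted values, so the chosen threshold should sit halfway between the largest petal length of class 0 and the smallest petal length of the remaining classes.

# The midpoint of these two neighbouring feature values should reproduce the 2.45 above.
print(X[y == 0, 0].max(), X[y != 0, 0].min())
print((X[y == 0, 0].max() + X[y != 0, 0].min()) / 2)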

X1_l, X1_r, y1_l, y1_r = split(X, y, best_d, best_v)

entropy(y1_l)

0.0

entropy(y1_r)

0.6931471805599453
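So the first split is already as informative as it can be on one side: the left branch contains a single class and has entropy 0, while the right branch still mixes the two remaining classes in equal proportion, giving log(2) ≈ 0.693, which matches the best_entropy reported above (0 + log(2)). A brief verification (added, not in the original post):

# Class counts in each branch, plus log(2) for comparison with entropy(y1_r).
print(np.unique(y1_l, return_counts=True))
print(np.unique(y1_r, return_counts=True))
print(log(2))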


best_entropy2, best_d2, best_v2 = try_split(X1_r, y1_r)
print("best_entropy2 =", best_entropy2)
print("best_d2 =", best_d2)
print("best_v2 =", best_v2)

best_entropy2 = 0.4132278899361904
best_d2 = 1
best_v2 = 1.75
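On the right branch, the best split is therefore on dimension 1 (petal width) at 1.75. The original post stops here, but the same split function can be applied once more to inspect the children of this second split; their entropies must sum to best_entropy2, and since that sum is greater than zero, at least one child is still impure.

# Apply the second split (addition, not in the original post) and inspect each child.
X2_l, X2_r, y2_l, y2_r = split(X1_r, y1_r, best_d2, best_v2)
print(entropy(y2_l))
print(entropy(y2_r))
print(entropy(y2_l) + entropy(y2_r))   # should equal best_entropy2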
