# Find the optimal decision-tree split using information entropy.
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

# Load the iris data set; keep only the last two features
# (columns 2:, i.e. petal length and petal width).
iris = datasets.load_iris()
X = iris.data[:, 2:]
y = iris.target
y.shape  # notebook display of the label-vector shape; no effect as a script

# Depth-2 tree that chooses splits by entropy rather than Gini impurity.
dt_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy")
dt_clf.fit(X, y)
def plot_decision_boundary(model, axis):
    """Shade the decision regions of *model* over a rectangular area.

    Parameters
    ----------
    model : fitted estimator exposing ``predict``.
    axis : sequence ``[x_min, x_max, y_min, y_max]`` bounding the area.

    Evaluates the model on a grid of ~100 points per unit along each axis
    and draws the predicted class of each grid point with ``contourf``.
    """
    # Local import kept from the original notebook style.
    from matplotlib.colors import ListedColormap

    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], int((axis[1] - axis[0]) * 100)).reshape(-1, 1),
        np.linspace(axis[2], axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1, 1),
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]

    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
    # FIX: the original passed linewidth=5, which is not a valid contourf
    # argument (filled contours have no line width); newer Matplotlib
    # versions reject unknown kwargs, so it has been removed.
    plt.contourf(x0, x1, zz, cmap=custom_cmap)
# Draw the tree's decision regions, then overlay the three iris classes
# (colors follow Matplotlib's default cycle, one scatter call per class).
plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3])
for cls in (0, 1, 2):
    plt.scatter(X[y == cls, 0], X[y == cls, 1])
# Simulate an entropy-based split by hand.
def split(X, y, d, value):
    """Partition (X, y) on feature *d* at threshold *value*.

    Returns ``X_left, X_right, y_left, y_right`` where the left part holds
    the samples with ``X[:, d] <= value`` and the right part those with
    ``X[:, d] > value``.
    """
    left_mask = X[:, d] <= value
    right_mask = X[:, d] > value
    return X[left_mask], X[right_mask], y[left_mask], y[right_mask]
from collections import Counter
from math import log


def entropy(y):
    """Empirical Shannon entropy (natural log) of the label vector *y*."""
    n = len(y)
    acc = 0.0
    for count in Counter(y).values():
        p = count / n
        acc -= p * log(p)
    return acc


def try_split(X, y):
    """Exhaustively search for the split minimising entropy(left) + entropy(right).

    Candidate thresholds are midpoints between consecutive distinct values
    of each feature (in sorted order).  Returns
    ``(best_entropy, best_dimension, best_value)``; the defaults
    ``(inf, -1, -1)`` come back unchanged if no valid split exists.
    """
    best_entropy = float('inf')
    best_d, best_v = -1, -1
    for d in range(X.shape[1]):
        order = np.argsort(X[:, d])
        for i in range(1, len(X)):
            lo, hi = X[order[i - 1], d], X[order[i], d]
            if lo == hi:
                continue  # identical values give no threshold between them
            v = (lo + hi) / 2
            # Partition the labels at v (the feature columns themselves
            # are not needed to score the split).
            y_l = y[X[:, d] <= v]
            y_r = y[X[:, d] > v]
            e = entropy(y_l) + entropy(y_r)
            if e < best_entropy:
                best_entropy, best_d, best_v = e, d, v
    return best_entropy, best_d, best_v
# Search every feature/threshold pair on the full data set and report it.
best_entropy, best_d, best_v = try_split(X, y)
for label, value in (('best_entropy = ', best_entropy),
                     ('best_d = ', best_d),
                     ('best_v = ', best_v)):
    print(label, value)
# Best split: dimension 0, threshold 2.45.
# Output:
# best_entropy =  0.6931471805599453
# best_d =  0
# best_v =  2.45
# Apply the best first-level split (dimension best_d at value best_v).
X1_l,X1_r,y1_l,y1_r = split(X,y,best_d,best_v)
# Entropy of the left partition — a bare expression, displayed in the
# notebook (output below shows 0.0, i.e. the left side is pure).
entropy(y1_l)
# Output: 0.0
# Entropy of the right partition — notebook display (output below is
# nonzero, so the right side still mixes classes and can be split again).
entropy(y1_r)
# Output: 0.6931471805599453
# Repeat the search on the full data set (same result as before).
best_entropy, best_d, best_v = try_split(X, y)
for label, value in (("best_entropy = ", best_entropy),
                     ("best_d = ", best_d),
                     ("best_v = ", best_v)):
    print(label, value)
# Output:
# best_entropy =  0.6931471805599453
# best_d =  0
# best_v =  2.45
# Recurse into the right partition: find its best second-level split.
best_entropy2, best_d2, best_v2 = try_split(X1_r, y1_r)
for label, value in (("best_entropy2 = ", best_entropy2),
                     ("best_d2 = ", best_d2),
                     ("best_v2 = ", best_v2)):
    print(label, value)
# Output:
# best_entropy2 =  0.4132278899361904
# best_d2 =  1
# best_v2 =  1.75