[Machine Learning] - Decision Tree (Watermelon Dataset)

Posted: 2019-03-04 20:05:52

A code implementation of the decision tree chapter from Zhou Zhihua's "watermelon book" (Machine Learning).

# Build a decision tree (ID3 with information gain) on the watermelon dataset
import math

# Compute the information entropy Ent(D) of a dataset
def calcEntropy(dataSet):
    mD = len(dataSet)
    dataLabelList = [x[-1] for x in dataSet]
    dataLabelSet = set(dataLabelList)
    ent = 0
    for label in dataLabelSet:
        mDv = dataLabelList.count(label)
        prop = float(mDv) / mD
        ent = ent - prop * math.log(prop, 2)
    return ent

# Split the dataset
# index - index of the feature to split on
# feature - the feature value to keep
# Returns the rows of dataSet whose value at index equals feature, with the index column removed
def splitDataSet(dataSet, index, feature):
    splitedDataSet = []
    for data in dataSet:
        if data[index] == feature:
            sliceTmp = data[:index]
            sliceTmp.extend(data[index + 1:])
            splitedDataSet.append(sliceTmp)
    return splitedDataSet

# Choose the best feature by information gain
# Returns the index of the best feature
def chooseBestFeature(dataSet):
    entD = calcEntropy(dataSet)
    mD = len(dataSet)
    featureNumber = len(dataSet[0]) - 1
    maxGain = -100
    maxIndex = -1
    for i in range(featureNumber):
        entDCopy = entD
        featureI = [x[i] for x in dataSet]
        featureSet = set(featureI)
        for feature in featureSet:
            splitedDataSet = splitDataSet(dataSet, i, feature)  # split the dataset
            mDv = len(splitedDataSet)
            entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
        if maxIndex == -1 or maxGain < entDCopy:
            maxGain = entDCopy
            maxIndex = i
    return maxIndex

# Return the most frequent label
def mainLabel(labelList):
    labelRec = labelList[0]
    maxLabelCount = -1
    labelSet = set(labelList)
    for label in labelSet:
        if labelList.count(label) > maxLabelCount:
            maxLabelCount = labelList.count(label)
            labelRec = label
    return labelRec

# Build the tree recursively
def createDecisionTree(dataSet, featureNames):
    labelList = [x[-1] for x in dataSet]
    if len(dataSet[0]) == 1:  # no attributes left to split on
        return mainLabel(labelList)  # use the majority label of this subset
    elif labelList.count(labelList[0]) == len(labelList):  # all samples share one label
        return labelList[0]
    bestFeatureIndex = chooseBestFeature(dataSet)
    bestFeatureName = featureNames.pop(bestFeatureIndex)
    myTree = {bestFeatureName: {}}
    featureList = [x[bestFeatureIndex] for x in dataSet]
    featureSet = set(featureList)
    for feature in featureSet:
        featureNamesNext = featureNames[:]
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        myTree[bestFeatureName][feature] = createDecisionTree(splitedDataSet, featureNamesNext)
    return myTree

# Read watermelon dataset 2.0
def readWatermelonDataSet():
    ifile = open("周志华_西瓜数据集2.txt")
    featureName = ifile.readline()  # header row
    labels = (featureName.split(' ')[0]).split(',')
    lines = ifile.readlines()
    dataSet = []
    for line in lines:
        tmp = line.split('\n')[0]
        tmp = tmp.split(',')
        dataSet.append(tmp)
    return dataSet, labels

def main():
    # read the data and build the tree
    dataSet, featureNames = readWatermelonDataSet()
    print(createDecisionTree(dataSet, featureNames))

if __name__ == "__main__":
    main()
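For reference, calcEntropy() computes the information entropy and chooseBestFeature() picks the attribute with the largest information gain, following the book's definitions (mD and mDv in the code correspond to |D| and |D^v|):

Ent(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k

Gain(D, a) = Ent(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} Ent(D^v)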

The resulting decision tree output is:

{'纹理': {'模糊': '否', '清晰': {'根蒂': {'稍蜷': {'色泽': {'乌黑': {'触感': {'硬滑': '是', '软粘': '否'}}, '青绿': '是'}}, '蜷缩': '是', '硬挺': '否'}}, '稍糊': {'触感': {'硬滑': '否', '软粘': '是'}}}}

Drawn as a diagram, it looks like this:
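As a text substitute for the figure, here is a minimal sketch of a helper that prints the nested dict as an indented tree; printTree is my own illustrative name, not part of the original code:

def printTree(tree, indent=""):
    # a leaf is a plain label string; an internal node is {featureName: {value: subtree}}
    if not isinstance(tree, dict):
        print(indent + "-> " + tree)
        return
    for featureName, branches in tree.items():
        for value, subtree in branches.items():
            print(indent + featureName + " = " + value)
            printTree(subtree, indent + "    ")

For example, printTree(createDecisionTree(dataSet, featureNames)) prints one line per branch, with leaves marked by "->".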

This part is not quite the same as the tree in the book.

Later I consulted a CSDN article,

which said that the decision tree needs to be completed.

Then I went back and read the book's pseudocode carefully.

The crux is how to interpret the part I underlined in red.

Does "each value" here mean each value in the original dataset, or each value in the split subset?

The code above takes the latter reading; the book takes the former.
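One concrete consequence of the first reading: in the tree printed above, the 色泽 node under 纹理=清晰 and 根蒂=稍蜷 has branches only for 乌黑 and 青绿, so looking up a sample whose 色泽 is 浅白 fails. A minimal classification helper shows where the lookup breaks; classify is my own illustrative name, not part of the original post, and featureNames must be the un-popped header list:

def classify(tree, featureNames, sample):
    # sample is a list of attribute values in the same order as featureNames
    while isinstance(tree, dict):
        featureName = next(iter(tree))           # attribute tested at this node
        value = sample[featureNames.index(featureName)]
        tree = tree[featureName][value]          # KeyError if the tree has no branch for this value
    return tree

With the completed tree built below, the same lookup instead reaches a leaf labelled with the parent node's majority class.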

Change the createDecisionTree() and readWatermelonDataSet() functions to the following:

# Build the complete decision tree
# featureNamesSet holds, for each feature in featureNames, the list of all values it takes
# labelListParent is the label list of the parent node
def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
    labelList = [x[-1] for x in dataSet]
    if len(dataSet) == 0:  # empty subset: fall back to the parent's majority label
        return mainLabel(labelListParent)
    elif len(dataSet[0]) == 1:  # no attributes left to split on
        return mainLabel(labelList)  # use the majority label of this subset
    elif labelList.count(labelList[0]) == len(labelList):  # all samples share one label
        return labelList[0]
    bestFeatureIndex = chooseBestFeature(dataSet)
    bestFeatureName = featureNames.pop(bestFeatureIndex)
    myTree = {bestFeatureName: {}}
    # branch on every value this feature takes in the full dataset,
    # not only the values present in the current subset
    featureList = featureNamesSet.pop(bestFeatureIndex)
    featureSet = set(featureList)
    for feature in featureSet:
        featureNamesNext = featureNames[:]
        featureNamesSetNext = featureNamesSet[:]  # shallow copy is enough; the inner lists are never modified
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        myTree[bestFeatureName][feature] = createFullDecisionTree(splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
    return myTree

# Read watermelon dataset 2.0
def readWatermelonDataSet():
    ifile = open("周志华_西瓜数据集2.txt")
    featureName = ifile.readline()  # header row
    featureNames = (featureName.split(' ')[0]).split(',')
    lines = ifile.readlines()
    dataSet = []
    for line in lines:
        tmp = line.split('\n')[0]
        tmp = tmp.split(',')
        dataSet.append(tmp)
    # collect the full list of values each feature takes across the whole dataset
    featureNamesSet = []
    for i in range(len(dataSet[0]) - 1):
        col = [x[i] for x in dataSet]
        colSet = set(col)
        featureNamesSet.append(list(colSet))
    return dataSet, featureNames, featureNamesSet
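The post does not show the matching change to main(); presumably it is adapted along these lines (a sketch using the new function signatures above; the label list passed at the root is only used as a fallback when a subset is empty):

def main():
    # read the data plus the full value list of every feature
    dataSet, featureNames, featureNamesSet = readWatermelonDataSet()
    labelList = [x[-1] for x in dataSet]
    print(createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelList))

if __name__ == "__main__":
    main()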

Now the result matches the tree in the book.
