转载

决策树

决策树

决策树是一种有监督的分类模型,由节点和边构成。节点分为两种,一种是

树的内部节点,表示一个特征或者属性,一种是叶节点,表示一个类(标签),

边代表特征的各种取值。如下图所示.

用决策树进行分类时,从根节点开始,对实例的某一特征进行测试,根据测试结果,

将实例分配到其子节点(每个子节点对应着该特征的一个取值),根据这个策略一直递归下去,

直到到达叶节点,最后将实例分到叶节点的类中。

比如上图,一个人(实例)为**(拥有房产,未婚,年收入<80)**, 就会按照蓝色虚线轨迹,被分到

无法偿还的类中。

如何构造决策树

信息增益

可以想象,如何构造决策树的关键是每次选出最优的特征划分实例集(数据集)。那么问题就转化为

每次如何选出最优特征划分实例集(数据集),如果要解决这个问题,首先了解一下熵的概念

熵是表示随机变量不确定性的度量

设随机变量 X X 是一个离散随机变量,其概率分布为

P ( X = x i ) = p i , i = 1 , 2 , . . . . , n P(X=x_i)=p_i, i=1,2,....,n

则其熵定义为

H ( X ) = ? i = 1 n p i ? l o g p i H(X)=-\sum\limits_{i=1}^np_i * logp_i

熵越大,变量的不确定性越大

条件熵

H ( Y X ) = i = 1 n p i H ( Y X = x i ) ? H(Y|X)=\sum\limits_{i=1}^np_iH(Y|X=x_i)?

p i = P ( X = x i ) , i = 1 , 2 , . . . n p_i=P(X=x_i), i=1,2,...n

信息增益表示得知特征 X X 的信息而使得类 Y Y 的信息的不确定减少的程度。

特征 A A 对训练数据集 D D 的信息增益 g ( D , A ) g(D,A) ,定义为集合 D D 的经验熵 H ( D ) H(D)

与特征 A A 给定条件下 D D 的经验条件熵 H ( D A ) H(D|A) 之差

g ( D , A ) = H ( D ) ? H ( D A ) g(D,A)=H(D)-H(D|A)

所以我们选择特征的时候,选择信息增益能最大的特征划分实例集

ID3算法

ID3算法是最简单的构建决策树算法,具体方法是:从根节点开始,对节点

计算所有可能的特征的信息增益,选择信息增益最大的特征作为节点的特征

由该特征的不同取值建立子节点,再对子节点递归地调用以上方法,构建决策

树,算法终止的条件是:没有特征可以再选择,或者所有叶节点都分别属于同一类

代码

下载链接https://download.csdn.net/download/zycxnanwang/10981838

def createPlot(inTree):
	#print (inTree)

	fig = plt.figure(1, facecolor='white')
	fig.clf()
	axprops = dict(xticks=[], yticks=[])
	createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
	plotTree.totalW = float(getNumLeafs(inTree))
	plotTree.totalD = float(getTreeDepth(inTree))
	plotTree.xOff = -0.5/plotTree.totalW
	plotTree.yOff = 1.0
	plotTree(inTree, (0.5, 1.0), '')
	plt.show()

plotTree.totalW是树的叶节点数量

plotTree.totalD是树的深度

plotTree.yOff=1.0容易理解,画布为1.0*1.0,所以初始高度为1.0

plot.xOff=-0.5/plotTree.totalW可能难以理解,不是plot.xOff=0吗?

首先我们可以从代码中知道x轴的划分是按照叶节点数量划分的,所以第一个叶节点的

位置为1/plotTree.totalW,第二个叶节点的位置为2/plotTree.totalW…,可以发现叶节点的

位置都是靠在它们空间中的最右端,为了在中间所以加上了一个**-0.5/plotTree.totalW**

def plotTree(myTree, parentPt, nodeTxt):
	numLeafs = getNumLeafs(myTree)
	#depth = getTreeDepth(myTree)
	firstStr = list(myTree.keys())[0] 
	cntrPt = (plotTree.xOff +  1/2 * float(numLeafs)/plotTree.totalW + 0.5/plotTree.totalW, \
	plotTree.yOff) 
	#print (cntrPt)
	plotMidText(cntrPt, parentPt, nodeTxt)
	plotNode(firstStr, cntrPt, parentPt, decisionNode)
	secondDict = myTree[firstStr]
	plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
	for key in secondDict.keys():
		if type(secondDict[key]).__name__=='dict':
			plotTree(secondDict[key], cntrPt, str(key))
		else:
			plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
			plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
				cntrPt, leafNode)
			plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
	plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

plotTree.xOff + 1/2 * float(numLeafs)/plotTree.totalW + 0.5/plotTree.totalW可能难以理解

中间位置的定义:子树叶节点的中间,plotTree.xOff减了一个0.5/plotTree.totalW,所以必须加上。

正文到此结束
本文目录