How do I extract the decision rules from a scikit-learn decision tree?
Can I extract the underlying decision rules (or "decision paths") from a trained decision tree as a list of text rules?

Something like:

if A > 0.4 then if B < 0.2 then if C > 0.8 then class = 'X'

and so on.

If anyone knows of a simple way to do this, it would be very helpful.
I believe this answer is more correct than the other answers here:
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)
This prints a valid Python function. Here is example output from a tree that is trying to return its input, a number between 0 and 10.
def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]
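For reference, a minimal sketch of how such a tree might be trained and handed to tree_to_code (the feature name f0, the toy data, and the tree depth are assumptions for illustration; the exact splits will depend on the data, and the function above uses Python 2 print statements):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# toy regression data: the target is simply the integer part of the input
# between 0 and 10, so the tree learns to roughly echo its input
X = np.arange(0, 10, 0.5).reshape(-1, 1)
y = np.floor(X.ravel())

reg = DecisionTreeRegressor(max_depth=4, random_state=0)  # depth is an assumption
reg.fit(X, y)

tree_to_code(reg, ["f0"])  # prints a tree(f0) function similar to the one above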
Here are some stumbling blocks that I see in other answers:

- Using tree_.threshold == -2 to decide whether a node is a leaf is not a good idea. What if it is a real decision node with a threshold of -2? Instead, you should look at tree.feature or tree.children_* (see the short sketch after this list).
- The line features = [feature_names[i] for i in tree_.feature] crashed with my version of sklearn, because some values of tree.tree_.feature are -2 (specifically for leaf nodes).
- There is no need to have multiple if statements in the recursive function; one is enough.
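To illustrate the first point, a short sketch of the leaf test I have in mind, using the constants that sklearn.tree._tree exposes:

from sklearn.tree import _tree

def is_leaf(tree_, node):
    # a node is a leaf when it has no children (and no splitting feature),
    # not when its threshold happens to equal -2
    return tree_.children_left[node] == _tree.TREE_LEAF
    # equivalently: tree_.feature[node] == _tree.TREE_UNDEFINED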
I created my own function to extract the rules from the decision trees created by sklearn:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.iloc[:, :2], df.dv)
This function first starts with the nodes (identified by -1 in the child arrays) and then recursively finds the parents. I call this a node's "lineage". Along the way, I grab the values I need to create if/then/else SAS logic:
def get_lineage(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]

    # get ids of child nodes
    idx = np.argwhere(left == -1)[:, 0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'

        lineage.append((parent, split, threshold[parent], features[parent]))

        if parent == 0:
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    for child in idx:
        for node in recurse(left, right, child):
            print node
The sets of tuples below contain everything I need to create SAS if/then/else statements. I do not like using do blocks in SAS, which is why I create logic describing a node's entire path. The single integer after the tuples is the ID of the terminal node in a path. All of the preceding tuples combine to create that node.
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6
I modified the code submitted by Zelazny7 to print some pseudocode:
def get_code(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node):
        if (threshold[node] != -2):
            print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node])
            print "} else {"
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node])
            print "}"
        else:
            print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)
If you call get_code(dt, df.columns) on the same example, you will get:
if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}
from StringIO import StringIO  # Python 2; use io.StringIO on Python 3
from sklearn import tree

# clf is an already-fitted decision tree classifier
out = StringIO()
tree.export_graphviz(clf, out_file=out)
print out.getvalue()
You will see a digraph of the tree. Then, clf.tree_.feature and clf.tree_.value are the array of splitting features per node and the array of values per node, respectively. You can refer to the github source for more details.
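As a rough illustration of what those arrays contain (the iris data and max_depth are just assumptions for the example):

from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

print(clf.tree_.feature)    # splitting feature index per node (-2 for leaves)
print(clf.tree_.threshold)  # split threshold per node (-2.0 for leaves)
print(clf.tree_.value)      # class counts stored at each node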
There is a new DecisionTreeClassifier method, decision_path, in the 0.18.0 release. The developers provide an extensive (well-documented) walkthrough.
The first part of the walkthrough code that prints the tree structure seems to be fine. However, I modified the code in the second part to interrogate one sample. My changes are marked with # <--.
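For context, the snippet below assumes the variables defined in the first part of the walkthrough; a rough sketch of that setup (the iris data and the train/test split are assumptions) looks like this:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)

clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

feature = clf.tree_.feature                 # splitting feature per node
threshold = clf.tree_.threshold             # split threshold per node
node_indicator = clf.decision_path(X_test)  # sparse matrix of nodes each sample passes through
leave_id = clf.apply(X_test)                # id of the leaf each sample ends up in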
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue  # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id]))  # <--

    else:  # <-- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]],  # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here
Change sample_id to see the decision paths for other samples. I have not asked the developers about these changes; it just seemed more intuitive when working through the example.
Here is a function that prints the rules of a scikit-learn decision tree under Python 3, using indentation of the conditional blocks to make the structure more readable:
def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features = ['f%d' % i for i in tree.tree_.feature]
    else:
        features = [feature_names[i] for i in tree.tree_.feature]

    def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit * depth
        if (threshold[node] != -2):
            print(offset + "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node], depth + 1)
            print(offset + "} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node], depth + 1)
            print(offset + "}")
        else:
            print(offset + "return " + str(value[node]))

    recurse(left, right, threshold, features, 0, 0)
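A quick usage sketch, reusing the dt and df from the dummy-data example earlier in this thread:

# print the rules of the dummy-data tree with 4-space indentation
print_decision_tree(dt, feature_names=list(df.columns), offset_unit='    ')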
The code below is my approach, under Anaconda Python 2.7 plus a package named "pydot-ng", to producing a PDF file with the decision rules. I hope it is helpful.
from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot

    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                         node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf" % name)

output_pdf(clf_, name='filename%s' % n)
The resulting tree graph is shown here.
Just because everyone was so helpful, I'll add a modification to Zelazny7's and Daniele's beautiful solutions. This one is for Python 2.7, with tabs to make it more readable:
def get_code(tree, feature_names, tabdepth=0):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
        if (threshold[node] != -2):
            print '\t' * tabdepth,
            print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node], tabdepth + 1)
            print '\t' * tabdepth,
            print "} else {"
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node], tabdepth + 1)
            print '\t' * tabdepth,
            print "}"
        else:
            print '\t' * tabdepth,
            print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)
I have modified the code submitted by Zelazny7 to extract SQL from the decision tree.
# SQL from decision tree
def get_lineage(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]
    le = '<='
    g = '>'

    # get ids of child nodes
    idx = np.argwhere(left == -1)[:, 0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'

        lineage.append((parent, split, threshold[parent], features[parent]))

        if parent == 0:
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    print 'case '
    for j, child in enumerate(idx):
        clause = ' when '
        for node in recurse(left, right, child):
            if len(str(node)) < 3:
                continue
            i = node
            if i[1] == 'l':
                sign = le
            else:
                sign = g
            clause = clause + i[3] + sign + str(i[2]) + ' and '
        clause = clause[:-4] + ' then ' + str(j)
        print clause
    print 'else 99 end as clusters'
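On the dummy data from earlier, calling get_lineage(dt, df.columns) should emit a CASE expression along these lines (the exact thresholds depend on the fitted tree):

case 
 when col1<=0.5 then 0
 when col1>0.5 and col2<=4.5 then 1
 when col1>0.5 and col2>4.5 and col1<=2.5 then 2
 when col1>0.5 and col2>4.5 and col1>2.5 then 3
else 99 end as clusters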