从零开始搭建决策树——手撕CART算法(C++)

CART决策树的基本原理见CART决策树原理

本文的C++代码基于C++ 20标准(不包含C++ modules),对于之前的标准,可能需要做一些适配。

CART分类树和回归树的内容各自在一个类中,分类树为CartClassifier类,回归树为CartRegression类。

数据结构设计

二叉树设计

// 二叉树结点
struct BinTreeNode
{
    std::string threshold_str_;
    double threshold_ = -1;
    std::string feature_name_;
    std::shared_ptr<BinTreeNode> left_ = nullptr;
    std::shared_ptr<BinTreeNode> right_ = nullptr;

    [[nodiscard]] std::shared_ptr<BinTreeNode> copy() const
    {
        auto node = std::make_shared<BinTreeNode>();
        node->threshold_ = threshold_;
        node->threshold_str_ = threshold_str_;
        node->feature_name_ = feature_name_;
        if (left_)
            node->left_ = left_->copy();
        if (right_)
            node->right_ = right_->copy();
        return node;
    }
};

copy模块用于二叉树结点的深复制,包括复制本身及其所有的子结点。

结点信息设计

struct Info
{
    std::shared_ptr<BinTreeNode> tree_;
    size_t num_leaf_ = 0;
    double a = 0;
    std::pair<bool, std::string> key_str_{};
    std::pair<bool, double> key_{};
};

实际上结点信息可以直接存储到二叉树结点BinTreeNode中。分开是为了保证代码的语义清晰,易于理解。

分类树

训练

/**
 * @brief
 * 训练决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param feature_names 属性名
 * @return 生成的决策树
 */
shared_ptr<BinTreeNode> CartClassifier::train(const vector<vector<string>>& X, const vector<string>& y, const vector<string>& feature_names)
{
    feature_names_ = feature_names;
    // 创建CART决策树
    tree_ = create_tree(X, y);
    return tree_;
}

训练函数通过传递常量引用形参,防止训练集和属性集被篡改。如需要修改,可以在函数内部设置副本,针对副本进行修改。create_tree是创建CART分类树的核心函数。

/**
 * @brief
 * 创建树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 训练好的决策树
 */
shared_ptr<BinTreeNode> CartClassifier::create_tree(const vector<vector<string>>& X, const vector<string>& y)
{
    // 若X中样本全属于同一类别C,则停止划分
    auto tree = make_shared<BinTreeNode>();
    if (unordered_set(y.begin(), y.end()).size() == 1)
    {
        tree->threshold_str_ = y.front();
        return tree;
    }
    // 若节点样本数小于min_samples_split,或者属性集上的取值均相同
    if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1)
    {
        tree->threshold_str_ = majority_y(y);
        return tree;
    }
    // 按照“基尼增益”,从属性值中选择最优分裂属性的最优切分点
    auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
    const string best_feature_name = feature_names_[best_feature_index];
    // 根据最优切分点,进行子树的划分
    vector<vector<string>> sub_X1, sub_X2;
    vector<string> sub_y1, sub_y2;
    for (int i = 0; i < X.size(); i++)
        if (X[i][best_feature_index] == best_split_point)
        {
            sub_X1.emplace_back(X[i]);
            sub_y1.emplace_back(y[i]);
        }
        else
        {
            sub_X2.emplace_back(X[i]);
            sub_y2.emplace_back(y[i]);
        }
    tree->feature_name_ = best_feature_name;
    tree->threshold_str_ = best_split_point;
    tree->left_ = create_tree(sub_X1, sub_y1);
    tree->right_ = create_tree(sub_X2, sub_y2);
    return tree;
}

create_tree函数是一个递归创建决策树的过程。首先判断三种递归中止条件:

  • X中样本全部属于同一类别;
  • 当前节点样本数小于min_samples_split_
  • 属性集上的取值均相同

若满足终止条件,则选择中最多的类别作为结果返回。若未满足终止条件,依次执行以下步骤:

  1. 根据基尼指数从属性值中选择最优分裂属性的最优切分点,具体过程如choose_best_point_to_split函数所示;
  2. 根据最优切分点对子树进行划分;
  3. 对于其子树再继续执行create_tree函数完成划分过程。
/**
 * @brief
 * 统计每个类别出现的次数,返回出现次数最大的类别ID
 * @param y 目标变量集合
 * @return 出现次数最大的类别
 */
string CartClassifier::majority_y(const vector<string>& y)
{
    // 统计y中的目标变量值的个数
    unordered_map<string, int> y_count;
    for (const string& v : y)
    {
        if (!y_count.contains(v))
            y_count[v] = 0;
        ++y_count[v];
    }
    return ranges::max_element(y_count, [](const pair<string, int>& a, const pair<string, int>& b) { return a.second < b.second; })->first;
}

majority_y用于计算节点中出现次数最多的类别。包含以下步骤:

  1. 初始化一个空映射;
  2. 遍历y并对其元素进行计数;
  3. 从映射中查找出现次数最多的类别。
/**
 * @brief
 * 选择最优切分点
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 最优切分点和最优切分点所在属性的索引
 */
pair<string, int> CartClassifier::choose_best_point_to_split(const vector<vector<string>>& X, const vector<string>& y)
{
    string best_split_point;
    int best_feature_index = -1;
    double best_gini_index = numeric_limits<double>::infinity();
    const size_t num_feature = X[0].size(); // 属性的个数
    for (int i = 0; i < num_feature; i++) // 遍历每个属性
    {
        // 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
        unordered_set<string> split_points;
        for (const vector<string>& x : X)
            split_points.emplace(x[i]);
        for (const string& split_point : split_points) // 计算各个候选切分点的基尼不纯度
        {
            vector<string> sub_y_left, sub_y_right;
            for (int j = 0; j < X.size(); j++)
                if (X[j][i] == split_point)
                    sub_y_left.emplace_back(y[j]);
                else
                    sub_y_right.emplace_back(y[j]);
            // 计算左子树的基尼不纯度
            const double gini_impurity_left = cal_gini_impurity(sub_y_left);
            // 计算右子树的基尼不纯度
            const double gini_impurity_right = cal_gini_impurity(sub_y_right);
            // 计算该切分点的基尼指数
            const double pro_left = static_cast<double>(sub_y_left.size()) / static_cast<double>(y.size()), pro_right = static_cast<double>(sub_y_right.size()) / static_cast<double>(y.size());
            if (const double gini_index = cal_gini_index(pro_left, pro_right, gini_impurity_left, gini_impurity_right); best_gini_index > gini_index) // 取基尼指数最大的属性索引和切分点
            {
                best_gini_index = gini_index;
                best_feature_index = i;
                best_split_point = split_point;
            }
        }
    }
    return {best_split_point, best_feature_index};
}

choose_best_point_to_split是CART分类树中最核心的函数,该函数负责选择最优切分点。根据前面的理论推导,该函数的目的是计算取得最大基尼增益的属性值。该函数遍历每个属性的每个属性值,根据是否等于属性值(二分类问题)将数据集分割到左右子树,依次计算左右子树的基尼不纯度,以及左右子树中数据样本在总样本中占的比例​,并且将代入cal_gini_index函数中计算基尼指数。最后选出具有最小基尼指数的属性值,作为当前节点的最优切分点,并返回最优切分点和最优分裂属性索引。

/**
 * @brief
 * 计算数据集的基尼不纯度
 * @param y 目标变量集合
 * @return 基尼不纯度 double
 */
double CartClassifier::cal_gini_impurity(const vector<string>& y)
{
    // 统计y中的目标变量值的个数
    unordered_map<string, int> y_count;
    for (const string& v : y)
    {
        if (!y_count.contains(v))
            y_count[v] = 0;
        ++y_count[v];
    }
    // 计算基尼不纯度
    double gini_impurity = 1;
    const auto num_samples = static_cast<double>(y.size());
    for (const int& k : y_count | views::values)
    {
        const double prob = k / num_samples;
        gini_impurity -= prob * prob;
    }
    return gini_impurity;
}

cal_gini_impurity用于计算基尼不纯度,包含以下步骤:

  1. 分析导入的数据集的最后一列(一般默认为数据类别)​,根据不同类别按出现次数统计到分类字典中;
  2. 遍历该字典,根据公式用1减去不同的类分布概率的平方和,得到最终的基尼不纯度。
/**
 * @brief
 * 计算基尼指数
 * @param pro_left 左子树比例
 * @param pro_right 右子树比例
 * @param gini_impurity_left 左子树的基尼不纯度
 * @param gini_impurity_right 右子树的基尼不纯度
 * @return 基尼指数 double
 */
double CartClassifier::cal_gini_index(const double pro_left, const double pro_right, const double gini_impurity_left, const double gini_impurity_right)
{
    return pro_left * gini_impurity_left + pro_right * gini_impurity_right;
}

cal_gini_index通过公式计算基尼指数。

预测

/**
 * @brief
 * 使用决策树进行预测
 * @param X 测试集属性值
 * @return 预测值
 */
vector<string> CartClassifier::predict(const vector<vector<string>>& X)
{
    vector<string> y_preds;
    for (const vector<string>& x : X)
        y_preds.emplace_back(classify(tree_, x));
    return y_preds;
}

遍历测试集X的每个样本,使用classify函数分别对其进行预测,最终返回拼接好的预测结果。

/**
 * @brief
 * 分类预测
 * @param tree 训练好的CART树
 * @param x 待分类样本
 * @return 预测类
 */
string CartClassifier::classify(const shared_ptr<BinTreeNode>& tree, const vector<string>& x)
{
    const string& first_str = tree->feature_name_; // 根节点
    const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
    const string& current_value = x[feature_index];
    if (tree->left_ && current_value == tree->threshold_str_)
        return classify(tree->left_, x);
    if (tree->right_ && current_value != tree->threshold_str_)
        return classify(tree->right_, x);
    return tree->threshold_str_;
}

通过调用classify进行预测分类。参数tree的根节点代表属性,根节点的左右孩子节点代表属性的取值及路由方向。在递归遍历过程中,从根节点开始,递归遍历CART分类树,最终路由到某个叶子节点,叶子节点上的值即为该决策树的预测结果。

剪枝

/**
 * @brief
 * 代价复杂度剪枝CCP
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 剪枝后的决策树集合
 */
vector<shared_ptr<BinTreeNode>> CartClassifier::pruning(const vector<vector<string>>& X, const vector<string>& y)
{
    // 递归计算对当前树的每个子树的g(ti),挑选最小的g(ti)进行剪枝,得到新的T,最终得到n个T
    return split_n_best_trees(X, y);
}

函数pruning根据不同的区间生成不同剪枝程度的决策树集合。集合中越后面的决策树,剪枝程度越高。

/**
 * @brief
 * 根据g(ti)生成n个误差最小的树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return n个误差最小的树
 */
vector<shared_ptr<BinTreeNode>> CartClassifier::split_n_best_trees(const vector<vector<string>>& X, const vector<string>& y)
{
    vector<shared_ptr<BinTreeNode>> trees;
    shared_ptr<BinTreeNode> tree = tree_->copy();
    while (tree)
        if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y))
        {
            trees.emplace_back(best_tree);
            tree = best_tree->copy();
        }
        else
            tree = nullptr;
    return trees;
}

split_n_best_trees函数通过调用split_1_best_trees函数递归生成棵预测误差最小的树,每一次递归的初始树均为上一次递归得到的最优剪枝树。为了在递归过程中不破坏上一轮得到的最优剪枝树,使用了深拷贝。

/**
 * @brief
 * 计算α值,选出α值最小的剪枝树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return α值最小的剪枝树
 */
shared_ptr<BinTreeNode> CartClassifier::split_1_best_trees(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y)
{
    // 构建节点信息总集合
    vector<Info> infoSet;
    // 计算数据集长度
    const size_t NT = X.size();
    // 计算误差增加率,并生成信息集合
    calErrorRatio(tree, X, y, NT, infoSet);
    if (infoSet.empty())
        return nullptr;
    // a的比较基准值
    double baseValue = 1;
    int bestNode = 0;
    for (int i = 0; i < infoSet.size(); i++)
        if (infoSet[i].a < baseValue)
        {
            baseValue = infoSet[i].a;
            bestNode = i;
        }
        else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
            bestNode = i;
    return prunBranch(tree, X, y, infoSet[bestNode]);
}

函数split_1_best_tree负责递归计算值,并且选出值最小的剪枝树。当前树的深度大于1时,开始进行CCP的迭代剪枝。在每次迭代内部,对每个分支节点进行的计算,并选取最小值对应的子树进行剪枝。如果求得的最小对应的子树有多个,则优先选取节点数目最多的子树作为修剪的对象。

/**
 * @brief
 * 计算非叶节点误差增加率
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param NT 数据集总样本数目
 * @param infoSet 所有节点的信息总集合
 * @return 各个节点的信息集
 */
Info CartClassifier::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y, const size_t NT, vector<Info>& infoSet)
{
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_ && (tree->left_->left_ || tree->left_->right_))
    {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] == tree->threshold_str_)
            {
                // 取第i行进subData
                // 相当于把label特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_str_ = {true, tree->threshold_str_};
        infoSet.emplace_back(info);
    }
    if (tree->right_ && (tree->right_->left_ || tree->right_->right_))
    {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] != tree->threshold_str_)
            {
                // 取第i行进subData
                // 相当于把label特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_str_ = {false, tree->threshold_str_};
        infoSet.emplace_back(info);
    }
    // 计算节点误差率
    const double Ct = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
    // 计算子树误差率
    const double CTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
    // 计算叶节点数目
    const size_t Nt = getNumLeaf(tree);
    const double a = Nt == 1 ? 2 : (Ct - CTt) / static_cast<double>(Nt - 1);
    return {tree, Nt, a};
}

每次迭代中的计算,也就是calErrorRatio函数。该函数主要计算节点的误差率​、节点对应子树的误差率​、子树叶子节点的数目的计算采用递归的方法,最终将所有info合并成节点信息集合。

/**
 * @brief
 * 计算非叶节点的误差
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartClassifier::nodeError(const vector<string>& y)
{
    // 找到数量最多的类别
    string majorClass = majority_y(y);
    // 游历数据集每个元素,找出正确样本个数,如果不一致,错误加1
    return ranges::count_if(y, [&majorClass](const string& v) { return v != majorClass; });
}

/**
 * @brief
 * 计算叶节点的误差
 * @param tree 生成的决策树
 * @param X
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartClassifier::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y)
{
    size_t error = 0;
    for (int i = 0; i < X.size(); i++)
        if (classify(tree, X[i]) != y[i])
            ++error;
    return error;
}

/**
 * @brief
 * 获取叶节点数量
 * @param tree 决策树
 * @return 返回树的叶节点
 */
size_t CartClassifier::getNumLeaf(const shared_ptr<BinTreeNode>& tree)
{
    size_t numLeafs = 0;
    if (tree->left_)
        numLeafs += getNumLeaf(tree->left_);
    if (tree->right_)
        numLeafs += getNumLeaf(tree->right_);
    if (!tree->left_ && !tree->right_)
        ++numLeafs;
    return numLeafs;
}
/**
 * @brief
 * 根据误差增加率,剪掉子树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param infoBran 需剪掉的子树信息集
 * @return 剪枝后的决策树
 */
shared_ptr<BinTreeNode> CartClassifier::prunBranch(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y, const Info& infoBran)
{
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_)
    {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] == tree->threshold_str_)
            {
                // 取第i行进subData
                // 相当于把label特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 找到数量最多的类别
        const string majorClass = majority_y(sub_y);
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ && tree->left_ == infoBran.tree_)
        {
            // 剪掉子树,即返回最大类
            tree->left_ = make_shared<BinTreeNode>();
            tree->left_->threshold_str_ = majorClass;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
    }
    if (tree->right_)
    {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] != tree->threshold_str_)
            {
                // 取第i行进subData
                // 相当于把label特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 找到数量最多的类别
        const string majorClass = majority_y(sub_y);
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (!infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ && tree->right_ == infoBran.tree_)
        {
            // 剪掉子树,即返回最大类
            tree->right_ = make_shared<BinTreeNode>();
            tree->right_->threshold_str_ = majorClass;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
    }
    return tree;
}

应用注意事项

  1. 为方便理解,代码仅考虑了离散字符串的分类,并未考虑其他离散值和连续值的分类,实际生产过程可能需要补充;
  2. 原则上来说,代码数据集中的字符串均需要通过编码(分类算法中编码无限制),以提升效率。为方便理解,本文章使用原始字符串,不影响结果;
  3. 代码中的feature_name仅作画图需要,实际生产如无该需求,可以去掉该变量;
  4. 对于CCP误差的计算,scikit-learn使用基尼不纯度进行代替,因其不用每次使用预测计算,提高了效率。但基尼不纯度与误差之间仅具有相关性,无法通过基尼不纯度推导出误差,仅用作近似计算;
  5. 代码未考虑缺失值的处理;
  6. 代码没有适配多线程场景;
  7. 其他可能的算法时空复杂度的优化。

回归树

训练

CartRegressor的创建和训练过程与CartClassifier类似。最重要的区别在于模型训练时切分点的选取。

/**
 * @brief
 * 训练决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param feature_names 属性名
 * @return 生成的决策树
 */
shared_ptr<BinTreeNode> CartRegressor::train(const vector<vector<double>>& X, const vector<double>& y, const vector<string>& feature_names)
{
    feature_names_ = feature_names;
    tree_ = create_tree(X, y);
    return tree_;
}

train的入参类型发生了变化,这是因为回归树使用的是连续类型数据。

/**
 * @brief
 * 创建树
 * @param X 映射后的数值属性集
 * @param y 映射后的数值目标变量集
 * @return 训练好的决策树
 */
shared_ptr<BinTreeNode> CartRegressor::create_tree(const vector<vector<double>>& X, const vector<double>& y)
{
    // 若X中样本全属于同一类别C,则停止划分
    auto tree = make_shared<BinTreeNode>();
    if (unordered_set(y.begin(), y.end()).size() == 1)
    {
        tree->threshold_ = y.front();
        return tree;
    }
    // 若节点样本数小于min_samples_split,或者属性集上的取值均相同
    if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1)
    {
        tree->threshold_ = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
        return tree;
    }
    // 按照“平方误差最小”,从feature_names中选择最优切分点
    auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
    const string_view best_feature_name = feature_names_[best_feature_index];
    // 根据最优切分点,进行子树的划分
    vector<vector<double>> sub_X1, sub_X2;
    vector<double> sub_y1, sub_y2;
    for (int i = 0; i < X.size(); i++)
        if (X[i][best_feature_index] <= best_split_point)
        {
            sub_X1.emplace_back(X[i]);
            sub_y1.emplace_back(y[i]);
        }
        else
        {
            sub_X2.emplace_back(X[i]);
            sub_y2.emplace_back(y[i]);
        }
    tree->feature_name_ = best_feature_name;
    tree->threshold_ = best_split_point;
    tree->left_ = create_tree(sub_X1, sub_y1);
    tree->right_ = create_tree(sub_X2, sub_y2);
    return tree;
}

在函数create_tree中,主要有3处与分类树不同:

  1. 当满足递归终止条件“节点样本数小于min_samples_split_”时,返回的预测值是该集合中所有目标变量的平均值;
  2. choose_best_point_to_split函数中,在回归树中采用“平方误差最小”的原则来选择最优切分点;
  3. 使用最优属性和最优切分点划分数据集时相较分类树(处理匹配字符串“”和“”的代码逻辑)做略微调整。
/**
 * @brief
 * 选择最优切分点
 * @param X 映射后的数值属性集
 * @param y 属性名称
 * @return 最优切分点和最优切分点所在属性的索引
 */
pair<double, int> CartRegressor::choose_best_point_to_split(const vector<vector<double>>& X, const vector<double>& y)
{
    double best_split_point = 0, best_loss_all = numeric_limits<double>::infinity();
    int best_feature_index = -1;
    const size_t num_feature = X[0].size(); // 属性的个数
    for (int i = 0; i < num_feature; ++i) // 遍历每个属性
    {
        // 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
        set<double> unique_feature_value;
        vector<double> split_points;
        for (const vector<double>& x : X)
            unique_feature_value.emplace(x[i]);
        auto lit = unique_feature_value.begin(), rit = lit;
        ++rit;
        while (rit != unique_feature_value.end())
        {
            split_points.emplace_back((*lit + *rit) / 2);
            ++lit;
            ++rit;
        }
        // 计算各个候选切分点的损失函数
        for (const double split_point : split_points)
        {
            vector<double> sub_y_left, sub_y_right;
            for (int j = 0; j < X.size(); j++)
                if (X[j][i] <= split_point)
                    sub_y_left.emplace_back(y[j]);
                else
                    sub_y_right.emplace_back(y[j]);
            const double sub_y_left_mean = accumulate(sub_y_left.begin(), sub_y_left.end(), 0.) / static_cast<double>(sub_y_left.size()), sub_y_right_mean = accumulate(sub_y_right.begin(), sub_y_right.end(), 0.) / static_cast<double>(sub_y_right.size());
            double loss_left = 0, loss_right = 0;
            // 计算左子树的损失函数
            for (const double j : sub_y_left)
                loss_left += pow(j - sub_y_left_mean, 2);
            // 计算右子树的损失函数
            for (const double j : sub_y_right)
                loss_right += pow(j - sub_y_right_mean, 2);
            // 计算该切分点的总损失函数
            // 取损失函数最小时的属性索引和切分点
            if (const double loss_all = loss_left + loss_right; best_loss_all > loss_all)
            {
                best_loss_all = loss_all;
                best_feature_index = i;
                best_split_point = split_point;
            }
        }
    }
    return {best_split_point, best_feature_index};
}

choose_best_point_to_split遍历所有属性值时,回归树中不再计算基尼不纯度和基尼增益,而是针对回归问题计算损失函数。分别计算了使用当前切分点划分的左右子树的残差平方和,再计算左右子树的总残差平方和。最后选出取得最小损失函数的切分点和属性索引,作为最优切分点和最优分裂属性。

预测

/**
 * @brief
 * 使用决策树进行预测
 * @param X 测试集属性值
 * @return 预测值
 */
vector<double> CartRegressor::predict(const vector<vector<double>>& X)
{
    vector<double> y_preds;
    for (const vector<double>& x : X)
        y_preds.emplace_back(regression(tree_, x));
    return y_preds;
}

/**
 * @brief
 * 回归预测
 * @param tree 训练好的树
 * @param x 待分类样本
 * @return 预测类
 */
double CartRegressor::regression(const shared_ptr<BinTreeNode>& tree, const vector<double>& x)
{
    const string& first_str = tree->feature_name_; // 根节点
    const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
    const double current_value = x[feature_index];
    if (tree->left_ && current_value <= tree->threshold_)
        return regression(tree->left_, x);
    if (tree->right_ && current_value > tree->threshold_)
        return regression(tree->right_, x);
    return tree->threshold_;
}

由于CART回归树与分类树的预测过程几乎完全相同,在此不做赘述。

剪枝

/**
 * @brief
 * 代价复杂度剪枝CCP
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 剪枝后的决策树集合
 */
vector<shared_ptr<BinTreeNode>> CartRegressor::pruning(const vector<vector<double>>& X, const vector<double>& y)
{
    // 递归计算对当前树的每个子树的g(ti),挑选最小的g(ti)进行剪枝,得到新的T,最终得到n个T
    return split_n_best_trees(X, y);
}

/**
 * @brief
 * 根据g(ti)生成n个误差最小的树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return n个误差最小的树
 */
vector<shared_ptr<BinTreeNode>> CartRegressor::split_n_best_trees(const vector<vector<double>>& X, const vector<double>& y)
{
    vector<shared_ptr<BinTreeNode>> trees;
    shared_ptr<BinTreeNode> tree = tree_->copy();
    while (tree)
        if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y))
        {
            trees.emplace_back(best_tree);
            tree = best_tree->copy();
        }
        else
            tree = nullptr;
    return trees;
}

/**
 * @brief
 * 计算α值,选出α值最小的剪枝树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return α值最小的剪枝树
 */
shared_ptr<BinTreeNode> CartRegressor::split_1_best_trees(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y)
{
    // 构建节点信息总集合
    vector<Info> infoSet;
    // 计算数据集长度
    const size_t NT = X.size();
    // 计算误差增加率,并生成信息集合
    calErrorRatio(tree, X, y, NT, infoSet);
    if (infoSet.empty())
        return nullptr;
    // a的比较基准值
    double baseValue = 1;
    int bestNode = 0;
    for (int i = 0; i < infoSet.size(); i++)
        if (infoSet[i].a < baseValue)
        {
            baseValue = infoSet[i].a;
            bestNode = i;
        }
        else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
            bestNode = i;
    return prunBranch(tree, X, y, infoSet[bestNode]);
}

/**
 * @brief
 * 计算非叶节点误差增加率
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param NT 数据集总样本数目
 * @param infoSet 所有节点的信息总集合
 * @return 各个节点的信息集
 */
Info CartRegressor::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y, size_t NT, vector<Info>& infoSet)
{
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_ && (tree->left_->left_ || tree->left_->right_))
    {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] <= tree->threshold_)
            {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_ = {true, tree->threshold_};
        infoSet.emplace_back(info);
    }
    if (tree->right_ && (tree->right_->left_ || tree->right_->right_))
    {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] > tree->threshold_)
            {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_ = {false, tree->threshold_};
        infoSet.emplace_back(info);
    }
    // 计算节点误差率
    const double Rt = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
    // 计算子树误差率
    const double RTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
    // 计算叶节点数目
    const size_t Nt = getNumLeaf(tree);
    const double a = Nt == 1 ? 2 : (Rt - RTt) / static_cast<double>(Nt - 1);
    return {tree, Nt, a};
}

/**
 * @brief
 * 计算非叶节点的误差
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartRegressor::nodeError(const vector<double>& y)
{
    // 计算节点的平方误差
    const double mean_y = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
    size_t error = 0;
    for (const double& val : y)
        error += static_cast<size_t>(pow(val - mean_y, 2));
    return error;
}

/**
 * @brief
 * 计算叶节点的误差
 * @param tree 生成的决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartRegressor::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y)
{
    size_t error = 0;
    for (int i = 0; i < X.size(); i++)
    {
        const double pred = regression(tree, X[i]);
        error += static_cast<size_t>(pow(pred - y[i], 2));
    }
    return error;
}

/**
 * @brief
 * 获取叶节点数量
 * @param tree 决策树
 * @return 返回树的叶节点
 */
size_t CartRegressor::getNumLeaf(const shared_ptr<BinTreeNode>& tree)
{
    size_t numLeafs = 0;
    if (tree->left_)
        numLeafs += getNumLeaf(tree->left_);
    if (tree->right_)
        numLeafs += getNumLeaf(tree->right_);
    if (!tree->left_ && !tree->right_)
        ++numLeafs;
    return numLeafs;
}

/**
 * @brief
 * 根据误差增加率,剪掉子树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param infoBran 需剪掉的子树信息集
 * @return 剪枝后的决策树
 */
shared_ptr<BinTreeNode> CartRegressor::prunBranch(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y, const Info& infoBran)
{
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_)
    {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] <= tree->threshold_)
            {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 计算该分支的平均值
        const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 && tree->left_ == infoBran.tree_)
        {
            // 剪掉子树,即返回平均值
            tree->left_ = make_shared<BinTreeNode>();
            tree->left_->threshold_ = mean_val;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
    }
    if (tree->right_)
    {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] > tree->threshold_)
            {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 计算该分支的平均值
        const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (!infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 && tree->right_ == infoBran.tree_)
        {
            // 剪掉子树,即返回平均值
            tree->right_ = make_shared<BinTreeNode>();
            tree->right_->threshold_ = mean_val;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
    }
    return tree;
}

回归树的剪枝与分类树类似,不同点在于回归树计算误差使用的是均方差。

应用注意事项

  1. 代码中的feature_name仅作画图需要,实际生产如无该需求,可以去掉该变量;
  2. 代码未考虑缺失值的处理;
  3. 分类树和回归树中的CCP算法,仅在误差计算中有区别。分类树中可以使用基尼系数或误分类率(从效率层面,推荐使用基尼系数),回归树中使用均方差;
  4. 代码没有适配多线程场景;
  5. 其他可能的算法时空复杂度的优化。
全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务