【笔试刷题】美团-2026.03.14-算法岗-改编真题

✅ 春招备战指南 ✅

💡 学习建议:

  • 先尝试独立解题
  • 对照解析查漏补缺

🧸 题面描述背景等均已深度改编,做法和题目本质基本保持一致。

🍹 感谢各位朋友们的订阅,你们的支持是我们创作的最大动力

🌸 目前本专栏已经上线200+套真题改编解析,后续会持续更新的

春秋招笔试机考招合集 -> 互联网必备刷题宝典🔗

美团-2026.03.14-算法岗

这套 3.14 美团算法岗题的区分度更明显。第一题仍然是结论型热身,第二题开始转到概率模型实现,第三题要求把树上“最近更大祖先”变形成额外跳边最短路,最后一题再回到删边离线并查集,整体更偏“建模 + 数据结构”。

题目一:区间平方因子扫描

这一题和研发岗的首题同源,本质仍然是完全平方数计数。只要抓住“因子成对、平方数例外”这个结论,就能直接转成区间开方统计。

题目二:词频朴素贝叶斯判别

这题的难点不在公式本身,而在于把题面给出的多项式朴素贝叶斯流程稳定实现出来。拉普拉斯平滑、对数后验和 JSON 输入输出这三件事,一件都不能漏。

题目三:祖先跳边最短路

原树上每个点都会向最近的更大权值祖先尝试补一条边,真正的核心是如何高效找到这个祖先。把 DFS 序和“当前根到点路径上的活跃祖先”维护好之后,再做一次 BFS 就能求全点最短路。

题目四:断边图连通峰值

这题和研发岗最后一题同源,关键还是逆序删边。只要方向反过来,并查集就能自然维护连通块最大点权,整题也就从“动态图”变成了标准离线题。

01. 区间平方因子扫描

问题描述

小美很喜欢研究数字的因子。

如果正整数 p 可以整除正整数 x,那么称 px 的一个因子。例如 12 的因子有 1, 2, 3, 4, 6, 12

现在给定一个区间 [l, r],请你计算这个区间里有多少个数的因子数量是奇数。

输入格式

输入一行,包含两个整数 lr

输出格式

输出一个整数,表示区间 [l, r] 中因子数量为奇数的数的个数。

样例输入 1

1 1

样例输出 1

1

样例说明 1

区间中只有数字 1,它只有一个因子,因此答案是 1

样例输入 2

4 5

样例输出 2

1

样例说明 2

4 的因子有 1, 2, 4,一共 3 个;5 的因子有 1, 5,一共 2 个。

所以区间中只有 4 满足条件,答案是 1

数据范围

  • 1 <= l <= r <= 10^9

题解

这题最重要的不是枚举因子,而是先看“为什么因子数量会是奇数”。

通常情况下,因子都会成对出现:

  • 如果 dx 的因子,那么 x / d 也是 x 的因子。
  • 这两个因子一般是不同的,所以会一对一对地贡献。

只有一种情况例外:这两个因子恰好相等。

也就是:

d = x / d

等价于:

d^2 = x

这说明,只有当 x 是完全平方数时,因子个数才会是奇数。

于是题目就被转化成一个非常直接的问题:

  • 统计区间 [l, r] 中完全平方数的个数。

设:

  • cnt(r) 表示不超过 r 的完全平方数个数。

那么显然有:

cnt(r) = floor(sqrt(r))

因为 1^2, 2^2, 3^2, ..., floor(sqrt(r))^2 都不超过 r

最终答案就是:

floor(sqrt(r)) - floor(sqrt(l - 1))

实现时需要注意一个小细节:直接调用浮点 sqrt 之后,可能因为精度问题出现偏差,所以最好在取整后再向上、向下各修正几次。

时间复杂度是 O(1),空间复杂度是 O(1)

参考代码

  • Python
import math
import sys

input = lambda: sys.stdin.readline().strip()


def isqrt_floor(x: int) -> int:
    # Python 直接使用整数平方根,避免浮点误差。
    return math.isqrt(x)


def solve() -> None:
    l, r = map(int, input().split())

    # 区间内完全平方数个数 = 前缀个数做差。
    ans = isqrt_floor(r) - isqrt_floor(l - 1)
    print(ans)


if __name__ == "__main__":
    solve()
  • Cpp
#include <bits/stdc++.h>
using namespace std;

using int64 = long long;

static int64 isqrt_floor(int64 x) {
    // 先用长双精度开方,再做微调,避免边界误差。
    int64 y = sqrtl((long double)x);
    while ((y + 1) * (y + 1) <= x) {
        ++y;
    }
    while (y * y > x) {
        --y;
    }
    return y;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int64 l, r;
    cin >> l >> r;

    // 前缀平方数个数做差即可。
    cout << isqrt_floor(r) - isqrt_floor(l - 1) << '\n';
    return 0;
}
  • Java
import java.io.BufferedInputStream;
import java.io.IOException;

public class Main {
    private static class FastScanner {
        private final BufferedInputStream in = new BufferedInputStream(System.in);
        private final byte[] buf = new byte[1 << 16];
        private int len = 0;
        private int ptr = 0;

        private int read() throws IOException {
            if (ptr >= len) {
                len = in.read(buf);
                ptr = 0;
                if (len <= 0) {
                    return -1;
                }
            }
            return buf[ptr++];
        }

        long nextLong() throws IOException {
            int c;
            do {
                c = read();
            } while (c <= ' ' && c != -1);

            long sign = 1;
            if (c == '-') {
                sign = -1;
                c = read();
            }

            long val = 0;
            while (c > ' ') {
                val = val * 10 + c - '0';
                c = read();
            }
            return val * sign;
        }
    }

    private static long isqrtFloor(long x) {
        long y = (long) Math.sqrt(x);

        // 修正浮点误差,保证返回 floor(sqrt(x))。
        while ((y + 1) * (y + 1) <= x) {
            y++;
        }
        while (y * y > x) {
            y--;
        }
        return y;
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner();
        long l = fs.nextLong();
        long r = fs.nextLong();

        long ans = isqrtFloor(r) - isqrtFloor(l - 1);
        System.out.println(ans);
    }
}

02. 词频朴素贝叶斯判别

问题描述

请你实现一个多项式朴素贝叶斯二分类器,并在给定训练集后对测试集输出预测标签。

输入中包含两个字段:

  • train:二维数组。每一行的最后一个数是类别标签 y,满足 y ∈ {0, 1};前面的数都是该样本各个词的非负整数词频。
  • test:二维数组。每一行只包含词频特征,维度与训练集一致。

分类器按下面的规则工作:

  1. 使用拉普拉斯平滑,平滑参数固定为 1
  2. 先验概率:

pi_c = N_c / N

其中 N_c 表示类别 c 的训练样本数,N 表示训练样本总数。

  1. 条件概率采用多项式朴素贝叶斯模型:

P(w | c) = (cnt[c][w] + 1) / (sum_c + m)

其中:

  • cnt[c][w] 表示所有标签为 c 的训练样本中,第 w 个词出现的总频次;
  • sum_c 表示类别 c 下所有词频的总和;
  • m 表示特征维度。
  1. 对每个测试样本 x,计算对数后验分数:

score(c) = log(pi_c) + Σ x_w * log(P(w | c))

  1. score(1) >= score(0),输出 1;否则输出 0

输入格式

输入为一个合法 JSON 对象,格式如下:

{
  "train": [[f11, ..., f1m, y1], ..., [fn1, ..., fnm, yn]],
  "test": [[t11, ..., t1m], ..., [tk1, ..., tkm]]
}

保证:

  • traintest 中每一行的特征维度一致;
  • train[i] 的最后一个元素是标签;
  • 其余元素都是非负整数词频。

输出格式

输出一个 JSON 数组,按顺序给出所有测试样本的预测标签。

例如:

[0,1,0]

样例输入

{"train":[[2,0,0,0],[3,1,0,0],[0,0,2,1],[0,1,3,1]],"test":[[1,0,0],[0,1,2]]}

样例输出

[0,1]

数据范围

  • 原题未额外给出独立的数据范围上界。
  • 保证输入 JSON 合法,且所有行长度与题意一致。

题解

这题本质上是一个标准的多项式朴素贝叶斯二分类器实现题。

真正要做的事情其实很固定:

  1. 先按类别统计训练样本数,得到先验概率。
  2. 再按类别累计每个词的总词频,配合拉普拉斯平滑得到条件概率。
  3. 最后对每个测试样本分别计算两类的对数后验分数,谁大就判给谁。

为什么要用多项式模型

这里每个特征给出的不是“是否出现”,而是“出现了多少次”。

所以更自然的模型是多项式朴素贝叶斯:

  • w 个词在类别 c 下的条件概率由该词在这一类中出现的总次数决定;
  • 测试样本里某个词出现 x_w 次,就把对应的 log P(w | c) 累加 x_w 次。

这也就是公式:

score(c) = log(pi_c) + Σ x_w * log(P(w | c))

拉普拉斯平滑

如果某个词在某个类别中从未出现过,那么不平滑的话:

P(w | c) = 0

一旦测试样本里这个词频大于 0,整条概率链就会变成 0,这会让模型过于脆弱。

所以这里按题意固定使用拉普拉斯平滑 k = 1

P(w | c) = (cnt[c][w] + 1) / (sum_c + m)

其中:

  • cnt[c][w] 是类别 c 下第 w 个词的总频次;
  • sum_c 是类别 c 下所有词频之和;
  • m 是特征维度,也就是词表大小。

为什么要取对数

如果直接连乘很多个概率,小数会越来越小,容易出现数值下溢。

改成对数以后:

  • 连乘变连加;
  • 比较大小时结果不变;
  • 实现也更稳定。

因此我们维护:

  • score(0)
  • score(1)

最后按题意:

  • score(1) >= score(0),输出 1
  • 否则输出 0

边界情况

需要额外注意一个点:

  • 如果某个类别在训练集中一个样本都没有,那么它的先验概率就是 0,对应分数可以直接视为负无穷。

这样处理后,另一个真实存在的类别自然会胜出。

复杂度分析

设:

  • n 为训练样本数;
  • k 为测试样本数;
  • m 为特征维度。

那么:

  • 统计训练集需要 O(nm)
  • 预测测试集需要 O(km)

总时间复杂度为 O((n + k) * m),空间复杂度为 O(m)

参考代码

  • Python
import json
import math
import sys


def solve() -> None:
    data = json.loads(sys.stdin.read())
    train = data["train"]
    test = data["test"]

    feature_count = len(train[0]) - 1

    # 统计每一类的样本数、总词频,以及每个词在该类中的总出现次数。
    class_count = [0, 0]
    total_words = [0, 0]
    cnt = [[0] * feature_count for _ in range(2)]

    for row in train:
        label = row[-1]
        class_count[label] += 1
        for i, value in enumerate(row[:-1]):
            cnt[label][i] += value
            total_words[label] += value

    ans = []
    n = len(train)
    for row in test:
        score = [-10**100, -10**100]
        for c in range(2):
            if class_count[c] == 0:
                continue

            # 先验概率取对数。
            score[c] = math.log(class_count[c] / n)
            denom = total_words[c] + feature_count

            # 多项式 NB:词频 value 会把对应对数概率累加 value 次。
            for i, value in enumerate(row):
                if value == 0:
                    continue
                prob = (cnt[c][i] + 1) / denom
                score[c] += value * math.log(prob)

        ans.append(1 if score[1] >= score[0] else 0)

    sys.stdout.write(json.dumps(ans, separators=(",", ":")))


if __name__ == "__main__":
    solve()
  • Cpp
#include <bits/stdc++.h>
using namespace std;

struct Parser {
    string s;
    int p = 0;

    explicit Parser(string str) : s(std::move(str)) {}

    void skip() {
        while (p < (int)s.size() && isspace((unsigned char)s[p])) {
            ++p;
        }
    }

    void expect(char ch) {
        skip();
        if (p >= (int)s.size() || s[p] != ch) {
            throw runtime_error("invalid json");
        }
        ++p;
    }

    string parseString() {
        skip();
        expect('"');
        string res;
        while (p < (int)s.size()) {
            char c = s[p++];
            if (c == '"') break;
            if (c == '\\' && p < (int)s.size()) {
                char nxt = s[p++];
                if (nxt == '"' || nxt == '\\' || nxt == '/') res.push_back(nxt);
                else if (nxt == 'b') res.push_back('\b');
                else if (nxt == 'f') res.push_back('\f');
                else if (nxt == 'n') res.push_back('\n');
                else if (nxt == 'r') res.push_back('\r');
                else if (nxt == 't') res.push_back('\t');
                else throw runtime_error("unsupported escape");
            } else {
                res.push_back(c);
            }
        }
        return res;
    }

    int parseInt() {
        skip();
        int sign = 1;
        if (s[p] == '-') {
            sign = -1;
            ++p;
        }
        int val = 0;
        while (p < (int)s.size() && isdigit((unsigned char)s[p])) {
            val = val * 10 + (s[p] - '0');
            ++p;
        }
        return sign * val;
    }

    vector<int> parseIntArray() {
        expect('[');
        vector<int> arr;
        skip();
        if (p < (int)s.size() && s[p] == ']') {
            ++p;
            return arr;
        }
        while (true) {
            arr.push_back(parseInt());
            skip();
            if (s[p] == ']') {
                ++p;
                break;
            }
            expect(',');
        }
        return arr;
    }

    vector<vector<int>> parseMatrix() {
        expect('[');
        vector<vector<int>> mat;
        skip();
        if (p < (int)s.size() && s[p] == ']') {
            ++p;
            return mat;
        }
        while (true) {
            mat.push_back(parseIntArray());
            skip();
            if (s[p] == ']') {
                ++p;
                break;
            }
            expect(',');
        }
        return mat;
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    string input, line;
    while (getline(cin, line)) {
        input += line;
    }

    Parser parser(input);
    parser.expect('{');

    vector<vector<int>> train, test;
    while (true) {
        parser.skip();
        if (parser.p < (int)parser.s.size() && parser.s[parser.p] == '}') {
            ++parser.p;
            break;
        }

        string key = parser.parseString();
        parser.expect(':');
        if (key == "train") train = parser.parseMatrix();
        else if (key == "test") test = parser.parseMatrix();
        else throw runtime_error("unexpected key");

        parser.skip();
        if (parser.p < (int)parser.s.size() && parser.s[parser.p] == '}') {
            ++parser.p;
            break;
        }
        parser.expect(',');
    }

    int n = (int)train.size();
    int m = (int)train[0].size() - 1;

    vector<long long> classCnt(2, 0), totalWords(2, 0);
    vector<vector<long long>> cnt(2, vector<long long>(m, 0));

    for (const auto& row : train) {
        int label = row.back();
        ++classCnt[label];
        for (int i = 0; i < m; ++i) {
            cnt[label][i] += row[i];
            totalWords[label] += row[i];
        }
    }

    vector<int> ans;
    for (const auto& row : test) {
        long double score[2];
        for (int c = 0; c < 2; ++c) {
            if (classCnt[c] == 0) {
                score[c] = -1e100L;
                continue;
            }
            score[c] = log((long double)classCnt[c] / n);
            long double denom = totalWords[c] + m;
            for (int i = 0; i < m; ++i) {
                if (row[i] == 0) continue;
                long double prob = (cnt[c][i] + 1.0L) / denom;
                score[c] += row[i] * log(prob);
            }
        }
        ans.push_back(score[1] >= score[0] ? 1 : 0);
    }

    cout << "[";
    for (int i = 0; i < (int)ans.size(); ++i) {
        if (i) cout << ",";
        cout << ans[i];
    }
    cout << "]\n";
    return 0;
}
  • Java
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

public class Main {
    static class Parser {
        String s;
        int p = 0;

        Parser(String s) {
            this.s = s;
        }

        void skip() {
            while (p < s.length() && Character.isWhitespace(s.charAt(p))) {
                p++;
            }
        }

        void expect(char ch) {
            skip();
            if (p >= s.length() || s.charAt(p) != ch) {
                throw new RuntimeException("invalid json");
            }
            p++;
        }

        String parseString() {
            skip();
            expect('"');
            StringBuilder sb = new StringBuilder();
            while (p < s.length()) {
                char c = s.charAt(p++);
                if (c == '"') {
                    break;
                }
                if (c == '\\' && p < s.length()) {
                    char nxt = s.charAt(p++);
                    if (nxt == '"' || nxt == '\\' || nxt == '/') sb.append(nxt);
                    else if (nxt == 'b') sb.append('\b');
                    else if (nxt == 'f') sb.append('\f');
                    else if (nxt == 'n') sb.append('\n');
                    else if (nxt == 'r') sb.append('\r');
                    else if (nxt == 't') sb.append('\t');
                    else throw new RuntimeException("unsupported escape");
                } else {
                    sb.append(c);
                }
            }
            return sb.toString();
        }

        int parseInt() {
            skip();
            int sign = 1;
            if (s.charAt(p) == '-') {
                sign = -1;
                p++;
            }
            int val = 0;
            while (p < s.length() && Character.isDigit(s.charAt(p))) {
                val = val * 10 + s.charAt(p) - '0';
                p++;
            }
            return val * sign;
        }

        List<Integer> parseIntArray() {
            expect('[');
            List<Integer> arr = new ArrayList<>();
            skip();
            if (p < s.length() && s.charAt(p) == ']') {
                p++;
                return arr;
            }
            while (true) {
                arr.add(parseInt());
                skip();
                if (s.charAt(p) == ']') {
                    p++;
                    break;
                }
                expect(',');
            }
            return arr;
        }

        List<List<Integer>> parseMatrix() {
            expect('[');
            List<List<Integer>> mat = new ArrayList<>();
            skip();
            if (p < s.length() && s.charAt(p) == ']') {
                p++;
                return mat;
            }
            while (true) {
                mat.add(parseIntArray());
                skip();
                if (s.charAt(p) == ']') {
                    p++;
                    break;
                }
                expect(',');
            }
            return mat;
        }
    }

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringBuilder input = new StringBuilder();
        String line;
        while ((line = br.readLine()) != null) {
            input.append(line);
        }

        Parser parser = new Parser(input.toString());
        parser.expect('{');

        List<List<Integer>> train = new ArrayList<>();
        List<List<Integer>> test = new ArrayList<>();
        while (true) {
            parser.skip();
            if (parser.p < parser.s.length() && parser.s.charAt(parser.p) == '}') {
                parser.p++;
                break;
            }

            String key = parser.parseString();
            parser.expect(':');
            if ("train".equals(key)) {
                train = parser.parseMatrix();
            } else if ("test".equals(key)) {
                test = parser.parseMatrix();
            } else {
                throw new RuntimeException("unexpected key");
            }

            parser.skip();
            if (parser.p < parser.s.length() && parser.s.charAt(parser.p) == '}') {
                parser.p++;
                break;
            }
            parser.expect(',');
        }

        int n = train.size();
        int m = train.get(0).size() - 1;

        long[] classCnt = new long[2];
        long[] totalWords = new long[2];
        long[][] cnt = new long[2][m];

        for (List<Integer> row : train) {
            int label = row.get(row.size() - 1);
            classCnt[label]++;
            for (int i = 0; i < m; i++) {
                cnt[label][i] += row.get(i);
                totalWords[label] += row.get(i);
            }
        }

        List<Integer> ans = new ArrayList<>();
        for (List<Integer> row : test) {
            double[] score = {-1e100, -1e100};
            for (int c = 0; c < 2; c++) {
                if (classCnt[c] == 0) {
                    continue;
                }
                score[c] = Math.log((double) classCnt[c] / n);
                double denom = totalWords[c] + m;
                for (int i = 0; i < m; i++) {
                    int value = row.get(i);
                    if (value == 0) {
                        continue;
                    }
                    double prob = (cnt[c][i] + 1.0) / denom;
            

剩余60%内容,订阅专栏后可继续查看/也可单篇购买

互联网刷题笔试宝典 文章被收录于专栏

互联网刷题笔试宝典,这里涵盖了市面上大部分的笔试题合集,希望助大家春秋招一臂之力

全部评论

相关推荐

点赞 评论 收藏
分享
评论
2
1
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务