蚂蚁笔试 蚂蚁笔试题 20260312

笔试时间:2026年3月12日

往年笔试合集:

2023春招秋招笔试合集

2024春招秋招笔试合集

第1题

题目

有一个长度为 $n$ 的数组 $a$。总共可以执行以下两种操作,其中第一种操作最多可执行一次,第二种操作也最多可执行一次(两种操作可以组合使用):

  • 选择两个不同的下标 $i, j$,将 $a_i$ 修改为 $a_i + a_j$,花费 $k$ 的代价;
  • 选择一个元素 $a_i$ 以及任意正整数 $d$,对其增加或减少 $d$,花费 $d$ 的代价。

希望花费最小的代价,使得 $\prod a_i = 0$。

请输出这个最小代价。

输入描述

每个测试文件均包含多组测试数据。

第一行输入一个整数 $t$,表示数据组数;

每组测试数据描述如下:

  1. 在一行上输入两个整数 $n, k$;
  2. 在一行上输入 $n$ 个整数 $a_1, a_2, \ldots, a_n$;

除此之外,保证单个测试文件中所有 $n$ 的和不超过 $2 \times 10^5$。

输出描述

对于每组测试数据,新起一行输出一个整数,表示最小花费代价。

样例输入

2
3 5
1 -2 3
4 10
0 0 0 -5

样例输出

1
0

样例说明

对于第一组测试数据只需要将 $a_1$ 减一即可。

参考题解

解题思路:

题目要求通过两种操作(每种最多用一次)使数组所有元素的乘积为 0,并最小化代价。核心在于:乘积为 0 等价于至少有一个元素为 0。因此问题转化为:用最小代价让某个元素变成 0。

有三种方式达成目标:

  • 方式 A(只用操作二): 直接选一个元素 $a_i$,用操作二将其变为 0,代价为 $|a_i|$。因此最小代价是数组中绝对值最小的元素的绝对值,记作 min_single
  • 方式 B(先操作一再操作二): 操作一将 $a_i$ 变为 $a_i + a_j$(代价 $k$),然后操作二将新值变为 0(代价 $|a_i + a_j|$)。总代价为 $k + |a_i + a_j|$。我们需要最小化 $|a_i + a_j|$,即数组中两数之和的绝对值的最小值,记作 min_pair
  • 方式 C(只用操作一): 可归入方式 B 中 min_pair = 0 的情况。

最终最小代价为:$\min(\text{min_single},\ k + \text{min_pair})$

min_pair 的方法:将数组排序后使用双指针法,左指针从最小开始,右指针从最大开始,在 $O(n \log n)$ 时间内找到两数之和的最小绝对值。

C++

#include <bits/stdc++.h>
using namespace std;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin >> t;
    while(t--){
        long long n, k;
        cin >> n >> k;
        vector<long long> a(n);
        for(int i = 0; i < n; i++) cin >> a[i];

        // 找绝对值最小的单个元素
        long long min_single = LLONG_MAX;
        for(int i = 0; i < n; i++){
            min_single = min(min_single, abs(a[i]));
        }

        if(min_single == 0){
            cout << 0 << "\n";
            continue;
        }

        // 排序 + 双指针找两数之和的最小绝对值
        sort(a.begin(), a.end());
        long long min_pair = LLONG_MAX;
        int left = 0, right = n - 1;
        while(left < right){
            long long s = a[left] + a[right];
            min_pair = min(min_pair, abs(s));
            if(s < 0) left++;
            else if(s > 0) right--;
            else break;
        }

        long long ans = min_single;
        if(n > 1){
            ans = min(ans, k + min_pair);
        }
        cout << ans << "\n";
    }
    return 0;
}

Java

import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int t = Integer.parseInt(br.readLine().trim());
        StringBuilder sb = new StringBuilder();
        while (t-- > 0) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int n = Integer.parseInt(st.nextToken());
            long k = Long.parseLong(st.nextToken());
            long[] a = new long[n];
            st = new StringTokenizer(br.readLine());
            for (int i = 0; i < n; i++) {
                a[i] = Long.parseLong(st.nextToken());
            }

            // 找绝对值最小的单个元素
            long minSingle = Long.MAX_VALUE;
            for (int i = 0; i < n; i++) {
                minSingle = Math.min(minSingle, Math.abs(a[i]));
            }

            if (minSingle == 0) {
                sb.append(0).append("\n");
                continue;
            }

            // 排序 + 双指针找两数之和的最小绝对值
            Arrays.sort(a);
            long minPair = Long.MAX_VALUE;
            int left = 0, right = n - 1;
            while (left < right) {
                long s = a[left] + a[right];
                minPair = Math.min(minPair, Math.abs(s));
                if (s < 0) left++;
                else if (s > 0) right--;
                else break;
            }

            long ans = minSingle;
            if (n > 1) {
                ans = Math.min(ans, k + minPair);
            }
            sb.append(ans).append("\n");
        }
        System.out.print(sb);
    }
}

Python

import sys

def solve():
    t = int(sys.stdin.readline())
    out = []
    for _ in range(t):
        n_k_line = sys.stdin.readline().split()
        if not n_k_line:
            break
        n = int(n_k_line[0])
        k = int(n_k_line[1])
        a = list(map(int, sys.stdin.readline().split()))

        # 寻找绝对值最小的单个元素
        min_single = float('inf')
        for val in a:
            if abs(val) < min_single:
                min_single = abs(val)

        if min_single == 0:
            out.append("0")
            continue

        # 排序 + 双指针寻找两数之和的最小绝对值
        a.sort()
        min_pair = float('inf')
        left = 0
        right = n - 1
        while left < right:
            current_sum = a[left] + a[right]
            if abs(current_sum) < min_pair:
                min_pair = abs(current_sum)
            if current_sum < 0:
                left += 1
            elif current_sum > 0:
                right -= 1
            else:
                break

        ans = min_single
        if n > 1:
            ans = min(ans, k + min_pair)
        out.append(str(ans))
    print('\n'.join(out))

solve()

第2题 - 离散型隐马尔可夫模型预测

题目

请在仅使用 NumPy 的前提下,实现离散型隐马尔可夫模型(HMM)的 Viterbi 动态规划。

给定:

  • 初始分布 $\pi$
  • 状态转移矩阵 $A$
  • 观测概率矩阵 $B$(列索引=观测符号)
  • 多条离散观测序列 $obs$,符号取值为整数

请为每条序列计算:

  1. 最优隐藏状态序列
  2. 该序列的对数概率

实现要求

  • 对数域计算,避免下溢:所有乘法用 log 加法,求和用 logsumexpmax
  • 只需前向传递 + 回溯,无需 EM 或 Baum-Welch 训练
  • 所有浮点以 float64 计算;输出对数概率保留 6 位小数(四舍五入)

输入描述

单行 JSON:

{
  "pi":[0.6,0.4],
  "A": [[0.7,0.3],[0.4,0.6]],
  "B": [[0.5,0.4,0.1],[0.1,0.3,0.6]],
  "obs":[[0,0],[2,2]...]
}

$N \leq 50$,$M \leq 50$,$S \leq 200$,序列长度 $\leq 500$。

所有概率矩阵行各自已归一化;不必检验。

输出描述

仅一行 JSON:

{
  "paths":[[q11,q12,...],[q21,...],...],
  "logp":[-2.253795,-2.448768,...]
}

次序须与输入 obs 保持一致。

样例输入

{"pi":[0.6,0.4],"A":[[0.7,0.3],[0.4,0.6]],"B":[[0.5,0.4,0.1],[0.1,0.3,0.6]],"obs":[[0,0]]}

样例输出

{"paths": [[0, 0]], "logp": [-2.253795]}

参考题解

解题思路:

使用 Viterbi 算法在对数域进行动态规划。

定义:

  • $\delta[t][i]$:在时间 $t$,以状态 $i$ 结尾的所有路径中的最大对数概率
  • $\psi[t][i]$:在时间 $t$ 状态为 $i$ 时,时间 $t-1$ 的最优前驱状态(用于回溯)

算法步骤:

  1. 初始化($t=0$):$\delta[0][i] = \log(\pi_i) + \log(B[i][o_0])$
  2. 递推($t=1$ 到 $T-1$):$\delta[t][j] = \max_i(\delta[t-1][i] + \log(A[i][j])) + \log(B[j][o_t])$,同时记录 $\psi[t][j] = \arg\max_i(\delta[t-1][i] + \log(A[i][j]))$
  3. 终止:$\log P^* = \max_j \delta[T-1][j]$,$q^*_{T-1} = \arg\max_j \delta[T-1][j]$
  4. 回溯:从终点状态开始,根据 $\psi$ 数组回溯得到完整路径。

C++

#include <bits/stdc++.h>
using namespace std;

// 简易 JSON 解析辅助(适用于本题简单结构)
// 实际比赛中可使用 nlohmann/json 或手动解析
#include <sstream>

int main(){
    string line;
    getline(cin, line);

    // 手动解析 JSON(简化版,适用于本题格式)
    // 解析 pi
    auto parseArray = [](const string& s, int& pos) -> vector<double> {
        vector<double> res;
        pos = s.find('[', pos);
        pos++;
        while(s[pos] != ']'){
            if(s[pos] == ',' || s[pos] == ' '){ pos++; continue; }
            int end = pos;
            while(end < (int)s.size() && s[end] != ',' && s[end] != ']') end++;
            res.push_back(stod(s.substr(pos, end - pos)));
            pos = end;
        }
        pos++; // skip ']'
        return res;
    };

    auto parseMatrix = [&](const string& s, int& pos) -> vector<vector<double>> {
        vector<vector<double>> res;
        pos = s.find('[', pos);
        pos++; // skip outer '['
        while(s[pos] != ']'){
            if(s[pos] == ',' || s[pos] == ' '){ pos++; continue; }
            res.push_back(parseArray(s, pos));
        }
        pos++; // skip outer ']'
        return res;
    };

    auto parseIntMatrix = [&](const string& s, int& pos) -> vector<vector<int>> {
        vector<vector<int>> res;
        pos = s.find('[', pos);
        pos++; // skip outer '['
        while(s[pos] != ']'){
            if(s[pos] == ',' || s[pos] == ' '){ pos++; continue; }
            // parse inner array
            int p2 = s.find('[', pos);
            p2++;
            vector<int> row;
            while(s[p2] != ']'){
                if(s[p2] == ',' || s[p2] == ' '){ p2++; continue; }
                int end = p2;
                while(end < (int)s.size() && s[end] != ',' && s[end] != ']') end++;
                row.push_back(stoi(s.substr(p2, end - p2)));
                p2 = end;
            }
            p2++; // skip ']'
            res.push_back(row);
            pos = p2;
        }
        pos++;
        return res;
    };

    int pos = 0;
    // find "pi"
    pos = line.find("\"pi\"");
    pos = line.find(':', pos) + 1;
    vector<double> pi = parseArray(line, pos);

    pos = line.find("\"A\"");
    pos = line.find(':', pos) + 1;
    vector<vector<double>> A = parseMatrix(line, pos);

    pos = line.find("\"B\"");
    pos = line.find(':', pos) + 1;
    vector<vector<double>> B = parseMatrix(line, pos);

    pos = line.find("\"obs\"");
    pos = line.find(':', pos) + 1;
    vector<vector<int>> obs = parseIntMatrix(line, pos);

    int N = pi.size();

    // 取对数
    vector<double> log_pi(N);
    for(int i = 0; i < N; i++) log_pi[i] = log(pi[i]);

    vector<vector<double>> log_A(N, vector<double>(N));
    for(int i = 0; i < N; i++)
        for(int j = 0; j < N; j++)
            log_A[i][j] = log(A[i][j]);

    int M = B[0].size();
    vector<vector<double>> log_B(N, vector<double>(M));
    for(int i = 0; i < N; i++)
        for(int j = 0; j < M; j++)
            log_B[i][j] = log(B[i][j]);

    // 输出结果
    cout << "{\"paths\": [";
    vector<string> path_strs, logp_strs;

    for(auto& ob : obs){
        int T = ob.size();
        if(T == 0){
            path_strs.push_back("[]");
            logp_strs.push_back("0.0");
            continue;
        }

        vector<vector<double>> delta(T, vector<double>(N));
        vector<vector<int>> psi(T, vector<int>(N, 0));

        // 初始化
        for(int i = 0; i < N; i++)
            delta[0][i] = log_pi[i] + log_B[i][ob[0]];

        // 递推
        for(int t = 1; t < T; t++){
            for(int j = 0; j < N; j++){
                double best = -1e300;
                int best_i = 0;
                for(int i = 0; i < N; i++){
                    double v = delta[t-1][i] + log_A[i][j];
                    if(v > best){ best = v; best_i = i; }
                }
                delta[t][j] = best + log_B[j][ob[t]];
                psi[t][j] = best_i;
            }
        }

        // 终止
        double log_p_star = -1e300;
        int q_star = 0;
        for(int j = 0; j < N; j++){
            if(delta[T-1][j] > log_p_star){
                log_p_star = delta[T-1][j];
                q_star = j;
            }
        }

        // 回溯
        vector<int> path(T);
        path[T-1] = q_st

剩余60%内容,订阅专栏后可继续查看/也可单篇购买

2025 春招笔试合集 文章被收录于专栏

2025打怪升级记录,大厂笔试合集 C++, Java, Python等多种语言做法集合指南

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务