题解 | 打家劫舍(三)
打家劫舍(三)
https://www.nowcoder.com/practice/58dad1054a0b41ab9b076e5bcc3160dc
import java.util.Scanner;
import java.util.ArrayList;

// Note: the class must be named Main and must not declare any package (judge requirement).
public class Main {

    /**
     * Reads a rooted tree from standard input and prints the maximum total value that can be
     * collected from nodes such that no chosen node is the parent of another chosen node.
     *
     * <p>Input: {@code n}, then {@code n} node values, then {@code n} parent indices. Parent
     * indices are 1-based; a parent index of 0 marks a root.
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            int n = in.nextInt();
            int[] value = new int[n];
            int[] parent = new int[n];
            for (int i = 0; i < n; i++) {
                value[i] = in.nextInt();
            }
            for (int i = 0; i < n; i++) {
                parent[i] = in.nextInt();
            }
            System.out.println(maxRobbery(value, parent));
        }
    }

    /**
     * Computes the best total value of a node set containing no parent-child pair.
     *
     * <p>Runs iteratively in O(n) time and memory, so deep (e.g. chain-shaped) trees cannot
     * overflow the call stack. Nodes may have any number of children. If several nodes have
     * parent 0, each is treated as the root of its own tree and the results are summed.
     *
     * @param value  value of each node, indexed from 0
     * @param parent 1-based parent index of each node; 0 marks a root
     * @return the maximum achievable total (0 when there are no nodes)
     * @throws IllegalArgumentException if the arrays differ in length or a parent index is
     *     outside {@code [0, n]}
     */
    static long maxRobbery(int[] value, int[] parent) {
        int n = value.length;
        if (parent.length != n) {
            throw new IllegalArgumentException(
                    "value and parent lengths differ: " + n + " vs " + parent.length);
        }

        // Child lists kept as array-backed singly linked lists (no boxing, no raw generics).
        int[] firstChild = new int[n];
        int[] nextSibling = new int[n];
        for (int i = 0; i < n; i++) {
            firstChild[i] = -1;
        }

        // order[] is both the BFS queue and the resulting top-down visiting order.
        int[] order = new int[n];
        int size = 0;
        for (int i = 0; i < n; i++) {
            int p = parent[i];
            if (p < 0 || p > n) {
                throw new IllegalArgumentException("parent[" + i + "] out of range: " + p);
            }
            if (p == 0) {
                order[size++] = i; // roots seed the traversal
            } else {
                nextSibling[i] = firstChild[p - 1];
                firstChild[p - 1] = i;
            }
        }
        // Each node has exactly one parent, so it is enqueued at most once and size <= n.
        for (int head = 0; head < size; head++) {
            for (int c = firstChild[order[head]]; c != -1; c = nextSibling[c]) {
                order[size++] = c;
            }
        }

        // take[u]: best for u's subtree when u is robbed; skip[u]: best when u is not robbed.
        long[] take = new long[n];
        long[] skip = new long[n];
        long best = 0;
        // Reverse BFS order guarantees every child is finished before its parent.
        for (int k = size - 1; k >= 0; k--) {
            int u = order[k];
            long robbed = value[u];
            long notRobbed = 0;
            for (int c = firstChild[u]; c != -1; c = nextSibling[c]) {
                robbed += skip[c]; // robbing u forbids robbing its children
                notRobbed += Math.max(take[c], skip[c]);
            }
            take[u] = robbed;
            skip[u] = notRobbed;
            if (parent[u] == 0) {
                best += Math.max(robbed, notRobbed);
            }
        }
        return best;
    }
}