Reference: LeetCode
Difficulty: Hard

Problem

Given a non-empty binary tree, find the maximum path sum.

For this problem, a path is defined as any sequence of nodes from some starting node to any node in the tree along the parent-child connections. The path must contain at least one node and does not need to go through the root.

Example:

    Input: [1,2,3]

           1
          / \
         2   3

    Output: 6

    Input: [-10,9,20,null,null,15,7]

       -10
       / \
      9  20
        /  \
       15   7

    Output: 42

Also, notice the following cases:

    Input         Output
    [-1]          -1
    [-1,-2,-3]    -1
    [-6,2,3,4]    6
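To try these cases locally, the arrays first have to be turned into trees. The sketch below assumes LeetCode's level-order serialization (children listed left to right, `null` for a missing child) and a minimal `TreeNode` class; `PathSumDemo` and its helper `fromLevelOrder` are hypothetical names. It runs the clamped recursion from the Clean Solution on both examples and all three cases.

```java
import java.util.ArrayDeque;
import java.util.Queue;

class PathSumDemo {
    // Hypothetical helper: builds a tree from LeetCode-style level order; null marks a missing child.
    static TreeNode fromLevelOrder(Integer... values) {
        if (values.length == 0 || values[0] == null) {
            return null;
        }
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < values.length) {
            TreeNode parent = queue.poll();
            // left child slot
            if (values[i] != null) {
                parent.left = new TreeNode(values[i]);
                queue.add(parent.left);
            }
            i++;
            // right child slot (may be missing at the end of the array)
            if (i < values.length && values[i] != null) {
                parent.right = new TreeNode(values[i]);
                queue.add(parent.right);
            }
            i++;
        }
        return root;
    }

    private static int best; // best path sum seen so far (demo only, not thread-safe)

    // Best downward path starting at node; also records the best path bending at node.
    private static int gain(TreeNode node) {
        if (node == null) {
            return 0;
        }
        int left = Math.max(gain(node.left), 0);
        int right = Math.max(gain(node.right), 0);
        best = Math.max(best, node.val + left + right);
        return node.val + Math.max(left, right);
    }

    static int maxPathSum(TreeNode root) {
        best = Integer.MIN_VALUE;
        gain(root);
        return best;
    }

    public static void main(String[] args) {
        System.out.println(maxPathSum(fromLevelOrder(1, 2, 3)));                        // 6
        System.out.println(maxPathSum(fromLevelOrder(-10, 9, 20, null, null, 15, 7))); // 42
        System.out.println(maxPathSum(fromLevelOrder(-1)));                             // -1
        System.out.println(maxPathSum(fromLevelOrder(-1, -2, -3)));                     // -1
        System.out.println(maxPathSum(fromLevelOrder(-6, 2, 3, 4)));                    // 6
    }
}

// Minimal binary tree node, modeled on LeetCode's TreeNode (assumption).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}
```

In `[-6,2,3,4]` the value 4 is the left child of 2, so the best path is 4 → 2 with sum 6, and it never touches the root.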

Analysis

Methods:

First, let’s understand the idea of maxGain(node).

Reference: link

maxGain(node)

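Most importantly, maxGain(node) must handle negative numbers. It returns the maximum sum of a downward path that starts at node: node itself is always included, at most one child's gain is added, and a child's gain is ignored when it is negative.

However, maxGain(node) alone does not give the correct result, because the max sum path does not necessarily go through node, and when it does, it may use both of its children. So we also keep a field that stores the best path sum seen so far.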
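A small sketch of maxGain(node) on its own makes the problem concrete. It assumes a minimal `TreeNode` class, and the class name `GainOnly` is hypothetical. On the second example, maxGain(root) is -10 + 20 + 15 = 25, while the real answer 42 (15 + 20 + 7) bends at node 20 and never touches the root.

```java
class GainOnly {
    // maxGain(node): best downward path sum starting at node; negative child gains are dropped.
    static int maxGain(TreeNode node) {
        if (node == null) {
            return 0;
        }
        int leftGain = Math.max(maxGain(node.left), 0);
        int rightGain = Math.max(maxGain(node.right), 0);
        return node.val + Math.max(leftGain, rightGain);
    }

    // The tree [-10,9,20,null,null,15,7], built by hand.
    static TreeNode example() {
        TreeNode root = new TreeNode(-10);
        root.left = new TreeNode(9);
        root.right = new TreeNode(20);
        root.right.left = new TreeNode(15);
        root.right.right = new TreeNode(7);
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = example();
        System.out.println(maxGain(root));       // 25 = -10 + 20 + 15, not the answer 42
        System.out.println(maxGain(root.right)); // 35 = 20 + 15
    }
}

// Minimal binary tree node, modeled on LeetCode's TreeNode (assumption).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}
```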

  1. Recursion
    • This problem can be simplified by implementing a function maxGain(node), which computes the maximum sum of a path that starts at node and goes down to some node below it (not necessarily a leaf). The rules inside maxGain(node) are:
      • The node itself must be used.
      • At most one child can be used, and neither is used when both have negative gains.
    • If we knew that node was the highest point of the max sum path, the answer would be node.val plus the non-negative gains of both children. However, the max sum path does not necessarily go through node: it could lie entirely inside a subtree below node.
    • Thus, at every node we compute the best path that has this node as its highest point (using both sides), compare it with a field variable, and update the field if necessary. maxGain(node) itself still returns the best one-sided gain, so that the parent can extend the path.
    • Note: The idea is not hard to understand, but designing the recursive function and handling the negative-number corner cases is tricky and complicated without the trick int leftGain = Math.max(maxGain(node.left), 0).
    • Time: $O(N)$, since maxGain is called exactly once per node (plus once per null child).
    • Space: $O(h)$ for the recursion stack, where $h$ is the height of the tree ($O(N)$ for a completely skewed tree).
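In symbols (the notation $g(v)$ for maxGain(v) is introduced here, with $g(\text{null}) = 0$), the two rules and the field update read:

$$
g(v) = v.\mathrm{val} + \max\bigl(0,\ g(v.\mathrm{left}),\ g(v.\mathrm{right})\bigr)
$$

$$
\mathrm{maxSum} = \max_{v} \Bigl( v.\mathrm{val} + \max\bigl(0, g(v.\mathrm{left})\bigr) + \max\bigl(0, g(v.\mathrm{right})\bigr) \Bigr)
$$

For [-10,9,20,null,null,15,7], naming each node by its value: $g(15) = 15$, $g(7) = 7$, $g(9) = 9$, $g(20) = 20 + 15 = 35$ and $g(-10) = -10 + 35 = 25$. The candidate at node 20 is $20 + 15 + 7 = 42$ and the candidate at the root is $-10 + 9 + 35 = 34$, so the answer is 42.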

Code

Original Version

Note: 140 ms

  • The time complexity is $O(N^2)$ in the worst case (e.g. a skewed tree), because maxPathSumOnOneSide recomputes the gains of a subtree every time maxPathSum visits one of its ancestors.
  • It separates the checking logic from maxGain(node) (here called maxPathSumOnOneSide) and puts it into maxPathSum(node). It turns out that both can be done in one function.
    public int maxPathSum(TreeNode root) {
        if (root == null) {
            return 0;
        }
        // root node must be included: both sides, one side, or root alone
        int lMax = maxPathSumOnOneSide(root.left);
        int rMax = maxPathSumOnOneSide(root.right);
        int p1 = root.val + lMax + rMax;
        int p2 = root.val + Math.max(lMax, rMax);
        int p3 = root.val;
        int maxVal = Math.max(p1, Math.max(p2, p3));

        // paths that do not include root: recurse only into existing subtrees,
        // so the 0 returned for null is never compared as a candidate
        if (root.left != null) {
            maxVal = Math.max(maxVal, maxPathSum(root.left));
        }
        if (root.right != null) {
            maxVal = Math.max(maxVal, maxPathSum(root.right));
        }

        return maxVal;
    }

    // max sum of a downward path starting at root (0 for an empty subtree)
    private int maxPathSumOnOneSide(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int lMax = maxPathSumOnOneSide(root.left);
        int rMax = maxPathSumOnOneSide(root.right);
        int p1 = root.val + lMax;
        int p2 = root.val + rMax;
        int p3 = root.val;
        return Math.max(p1, Math.max(p2, p3));
    }
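To see the quadratic behavior on a skewed tree, the sketch below copies the Original Version into static methods (an adaptation, not the post's exact code) and adds a hypothetical `visits` counter to maxPathSumOnOneSide. On a left-skewed chain of $n$ nodes, the call maxPathSum(node) for a node whose subtree has $s$ nodes makes the helper visit the $s - 1$ nodes below it, so the total is $n(n-1)/2$.

```java
class OriginalCost {
    static long visits = 0; // hypothetical counter, not part of the original solution

    static int maxPathSum(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int lMax = maxPathSumOnOneSide(root.left);
        int rMax = maxPathSumOnOneSide(root.right);
        int maxVal = Math.max(root.val + lMax + rMax, Math.max(root.val + Math.max(lMax, rMax), root.val));
        if (root.left != null) {
            maxVal = Math.max(maxVal, maxPathSum(root.left));
        }
        if (root.right != null) {
            maxVal = Math.max(maxVal, maxPathSum(root.right));
        }
        return maxVal;
    }

    private static int maxPathSumOnOneSide(TreeNode root) {
        if (root == null) {
            return 0;
        }
        visits++; // one visit per non-null node reached by the helper
        int lMax = maxPathSumOnOneSide(root.left);
        int rMax = maxPathSumOnOneSide(root.right);
        return Math.max(root.val + lMax, Math.max(root.val + rMax, root.val));
    }

    // Builds a left-skewed chain of n >= 1 nodes, each with value 1.
    static TreeNode leftChain(int n) {
        TreeNode root = new TreeNode(1);
        TreeNode cur = root;
        for (int i = 1; i < n; i++) {
            cur.left = new TreeNode(1);
            cur = cur.left;
        }
        return root;
    }

    // Runs the original algorithm on a chain and returns how many helper visits it made.
    static long countVisits(int n) {
        visits = 0;
        maxPathSum(leftChain(n));
        return visits;
    }

    public static void main(String[] args) {
        for (int n : new int[] {10, 100, 1000}) {
            System.out.println(n + " nodes -> " + countVisits(n) + " visits");
        }
    }
}

// Minimal binary tree node, modeled on LeetCode's TreeNode (assumption).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}
```

For n = 10, 100 and 1000 this prints 45, 4950 and 499500 visits: roughly 100 times more work for 10 times more nodes.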

Initial Improvement

Note: 2 ms

  • This code is long and complicated because I check many corner cases for negative numbers. For example:
    • [-3] should return -3 rather than 0. root == null returns 0, but that 0 is not a valid path sum because it does not come from a node, so it must not be compared with -3.
    • [2, -1] should return 2 rather than 2 + (-1) = 1.
  • Therefore, when calculating the max sum, I have to try many combinations of root.val and the gains of root.left and root.right.
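Both corner cases can be reproduced with a small sketch, assuming a minimal `TreeNode` class. The methods `nullAsZero` and `throughRootUnclamped` are hypothetical buggy variants written only to show each failure; they are not code from this post.

```java
class CornerCases {
    // Best downward path starting at node (0 for an empty subtree), like maxPathSumOnOneSide.
    static int oneSide(TreeNode node) {
        if (node == null) {
            return 0;
        }
        int l = oneSide(node.left);
        int r = oneSide(node.right);
        return node.val + Math.max(0, Math.max(l, r));
    }

    // Buggy variant: the 0 returned for a null subtree competes with real path sums.
    static int nullAsZero(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int here = root.val + Math.max(0, oneSide(root.left)) + Math.max(0, oneSide(root.right));
        return Math.max(here, Math.max(nullAsZero(root.left), nullAsZero(root.right)));
    }

    // Buggy variant: the path through root always adds both child gains, even negative ones.
    static int throughRootUnclamped(TreeNode root) {
        return root.val + oneSide(root.left) + oneSide(root.right);
    }

    // Correct path through root: negative child gains are dropped.
    static int throughRootClamped(TreeNode root) {
        return root.val + Math.max(0, oneSide(root.left)) + Math.max(0, oneSide(root.right));
    }

    public static void main(String[] args) {
        TreeNode single = new TreeNode(-3);   // [-3]
        TreeNode twoNodes = new TreeNode(2);  // [2,-1]
        twoNodes.left = new TreeNode(-1);
        System.out.println(nullAsZero(single));             // 0, but the answer is -3
        System.out.println(throughRootUnclamped(twoNodes)); // 1, but the answer is 2
        System.out.println(throughRootClamped(twoNodes));   // 2
    }
}

// Minimal binary tree node, modeled on LeetCode's TreeNode (assumption).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}
```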
    private Integer maxVal; // null until the first path sum is recorded

    public int maxPathSum(TreeNode root) {
        maxVal = null;
        helper(root);
        return maxVal;
    }

    private int helper(TreeNode root) { // returns the max gain of a downward path starting at root
        if (root == null) {
            return 0; // an empty subtree adds nothing to a path
        }

        // a leaf is a path on its own
        if (root.left == null && root.right == null) {
            maxVal = (maxVal == null) ? root.val : Math.max(maxVal, root.val);
            return root.val;
        }

        int lMax = helper(root.left);
        int rMax = helper(root.right);
        int p1 = root.val + lMax;
        int p2 = root.val + rMax;
        int p3 = root.val;

        // update the max sum: root alone or one existing child's gain alone,
        // root plus one side, or root plus both sides
        int oneMax = root.val;
        if (root.left != null) {
            oneMax = Math.max(oneMax, lMax);
        }
        if (root.right != null) {
            oneMax = Math.max(oneMax, rMax);
        }
        int twoMax = root.val + Math.max(lMax, rMax);
        int threeMax = root.val + lMax + rMax;
        int newMax = Math.max(oneMax, Math.max(twoMax, threeMax));
        maxVal = (maxVal == null) ? newMax : Math.max(maxVal, newMax);

        return Math.max(p1, Math.max(p2, p3));
    }

Clean Solution

Note: 1 ms

  • The key lesson is how to take the maximum over several candidate values. If maxGain(node) returns a negative number, adding it can only make a path worse, so we never include it. This one rule handles all the corner cases above.
  • If the maxGain(node) of a node’s left or right child is negative, we drop that value (treat it as 0) and start the path at node instead.
    private int maxSum;

    public int maxPathSum(TreeNode root) {
        maxSum = Integer.MIN_VALUE;
        maxGain(root);
        return maxSum;
    }

    private int maxGain(TreeNode root) {
        if (root == null) {
            return 0;
        }

        // max gain on the left and right subtrees of the current node;
        // a negative gain is replaced by 0, i.e. that side is not used
        int leftGain = Math.max(maxGain(root.left), 0);
        int rightGain = Math.max(maxGain(root.right), 0);

        // is it better to start a new path with the current node as its highest point? (can use both sides)
        // no need to also check root.val + leftGain: both gains are >= 0, so this sum is never smaller
        int newPathSum = root.val + leftGain + rightGain;
        this.maxSum = Math.max(this.maxSum, newPathSum);

        // return the max gain to the parent (can only use one side)
        return root.val + Math.max(leftGain, rightGain);
    }
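Because the recursion depth equals the tree height, a very deep (skewed) tree could overflow the JVM call stack; the exact depth at which that happens depends on the thread stack size, so treat it as an assumption rather than a measured limit. One way around it, sketched below as a variant that is not from the original post, is to run the same maxGain rules in iterative post-order with an explicit stack and keep each node's gain in a `HashMap`. The class name `IterativeMaxPathSum` and the helper `leftChain` are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;

class IterativeMaxPathSum {
    static int maxPathSum(TreeNode root) {
        int maxSum = Integer.MIN_VALUE;
        Map<TreeNode, Integer> gain = new HashMap<>(); // maxGain(node) for finished nodes
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode lastVisited = null;
        TreeNode node = root;
        while (node != null || !stack.isEmpty()) {
            if (node != null) {
                stack.push(node);
                node = node.left; // go as far left as possible
            } else {
                TreeNode top = stack.peek();
                if (top.right != null && top.right != lastVisited) {
                    node = top.right; // right subtree not processed yet
                } else {
                    stack.pop();
                    // both children are finished: apply the same rules as maxGain(node)
                    int leftGain = Math.max(gain.getOrDefault(top.left, 0), 0);
                    int rightGain = Math.max(gain.getOrDefault(top.right, 0), 0);
                    maxSum = Math.max(maxSum, top.val + leftGain + rightGain);
                    gain.put(top, top.val + Math.max(leftGain, rightGain));
                    lastVisited = top;
                }
            }
        }
        return maxSum;
    }

    // Builds a left-skewed chain of n >= 1 nodes, each holding val.
    static TreeNode leftChain(int n, int val) {
        TreeNode root = new TreeNode(val);
        TreeNode cur = root;
        for (int i = 1; i < n; i++) {
            cur.left = new TreeNode(val);
            cur = cur.left;
        }
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(-10); // [-10,9,20,null,null,15,7]
        root.left = new TreeNode(9);
        root.right = new TreeNode(20);
        root.right.left = new TreeNode(15);
        root.right.right = new TreeNode(7);
        System.out.println(maxPathSum(root));                  // 42
        System.out.println(maxPathSum(leftChain(100_000, 1))); // 100000, no deep recursion involved
    }
}

// Minimal binary tree node, modeled on LeetCode's TreeNode (assumption).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}
```

The trade-off is memory: the map holds a gain for every node, so extra space becomes $O(N)$ instead of the recursion's $O(h)$, in exchange for not depending on the call stack.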