[LeetCode] 讓程式快2倍的解法 #15 3Sum

這題稍微分析一下數學就可以讓程式快 2 倍,但大家好像都沒這樣做。

"簡單的"數學分析

題目要我們找三個數字加起來等於 0:


num1 + num2 + num3 = 0

如果我們每個數字分成正負號的話,什麼情況下這個式子一定無法成立?

我們要加快程式就是要省去一些運算,所以如果有一開始就知道不用算的東西那我們就最開心了。

我們會發現 3 個數字都正號或 3 個數字都負號你連算都不用算,因為全部加起來不是大於 0 就是小於 0:


num1(-) + num2(-) + num3(-) < 0    -> No solution
num1(+) + num2(+) + num3(+) > 0    -> No solution

那麼剩下的情況我們就一定有正有負:

num1(-) + num2(-) + num3(+) = 0
num1(+) + num2(+) + num3(-) = 0

知道了這件事我們就可以不用傻傻的跑去計算 3 個正數或是 3 個負數的情況了。

想法: 找 num3

數學為什麼有用在這邊就又得了一分,你看到上面都只剩兩種可能性了,那演算法已經呼之欲出了:

  1. 我們用任意 2 個負數來找剩下 1 個正數存不存在
  2. 我們用任意 2 個正數來找剩下 1 個負數存不存在

這時候我們在想演算法的時候大概會這樣想:

  • 任意 2 個數? 那我就用 2 個 for loop 來暴力組合兩個數字
  • 來找一個數字在不在? 我可以用 hash table 用 O(1) 時間找到他

可是 for loop 暴力解通常是建立在數字已經從小到大排好,而且沒有重複的數字的前提下,所以我們大概是躲不掉排序 + 移除重複的數字的這兩個步驟。

演算法: 找 num3

在移除重複的數字前我們先想要知道每個數字有幾個,這個步驟很重要我們後面就知道為什麼:


然後我們就 sort + 移除重複的數字:


接著我們就試著把所有可能的 2 個負數 num1 + num2 加起來,再來找正的 num3 在不在:


(如果影片無法播放可以來看 PowerPoint 投影片)

再來我們試著把所有可能的 2 個正數 num1 + num2 加起來,再來找負的 num3 在不在:


實作: 用 map

Hash table 要用什麼資料結構? 第一個想到的當然是 map,那我們就用 map 來做吧:


#include <map>

class Solution {
 public:
  std::vector<std::vector<int>> threeSum(std::vector<int> &nums) {
    std::vector<std::vector<int>> output;
    std::map<int, int> negativeCount;
    std::map<int, int> positiveCount;
    for (const int num : nums) {
      if (num < 0) {
        addMapBy1(negativeCount /*will change*/, num);
      } else {
        addMapBy1(positiveCount /*will change*/, num);
      }
    }

    // 2 negative numbers + 1 positive number
    for (const auto &[num1, count1] : negativeCount) {
      for (auto it = negativeCount.find(num1); it != negativeCount.end();
           it++) {
        const int num2 = it->first;
        if (num1 == num2 && count1 <= 1) continue;

        const int num3 = -(num1 + num2);
        if (num3 < 0) continue;
        if (positiveCount.find(num3) == positiveCount.end()) continue;

        output.push_back({num1, num2, num3});
      }
    }

    // 1 negative number + 2 positive numbers
    for (const auto &[num1, count1] : positiveCount) {
      for (auto it = positiveCount.find(num1); it != positiveCount.end();
           it++) {
        const int num2 = it->first;
        if (num1 == num2 && count1 <= 1) continue;

        const int num3 = -(num1 + num2);
        if (num3 >= 0) continue;
        if (negativeCount.find(num3) == negativeCount.end()) continue;

        output.push_back({num1, num2, num3});
      }
    }

    // Special case: 3 zeros
    if (positiveCount.find(0) != positiveCount.end()) {
      if (positiveCount.at(0) >= 3) {
        output.push_back({0, 0, 0});
      }
    }

    return output;
  }

  void addMapBy1(std::map<int, int> &count, const int num) {
    if (count.find(num) == count.end()) count[num] = 0;
    count[num]++;
  }
};

最後一個 special case 是因為我把 0 不當作是正數或負數,但他卻在 positiveCount 裡面所以不特別處理他就會得到一個 Wrong Answer 當 nums = {0, 0, 0} 的時候。

可是 submit 完就發現速度不是很理想:


Callgrind 分析後才發現 47% 左右的時間花在做每次 map::find 的時候,那兩個 map::find 加起來幾乎 85% 的時間都花在 map::find:


我們會以為大部分的時間都花在兩層 for loop 上,看來我們就必須解決 "找東西" 的 bottleneck。

實作: 改用 unordered_map?

題目有說一定要按照什麼順序輸出嗎? 既然沒有我就不用一定要用 map。

我們上網找一下 map vs unordered_map 的 lookup 時間是不是有差異,結果看其他人的 benchmark 還真的有,unordered_map 在 fetch 的時間上通常比較快。

於是我們就把 map 換成 unordered_map 看看,程式碼我就不貼了,可以在 GitHub 看,結果發現是有比較快啦:


但貪心的我覺得還不夠快,再用 Callgrind 分析後發現他還是大部分時間都花在找東西上:


看來要更快就只能放棄好用的 STL containers。

實作: 改用 Array?

要追求速度有時候就是得用最原始的東西,那 hash table 最原始的型態是什麼? 就是 array。

看到題目說數字的範圍這麼小的時候我其實就見獵心喜想要用 array 來做 hash table 了:


如果我們用一個超大的 array 來當 hash table 可以嗎? 看 10^5 就覺得不會花太多記憶體,應該是可以吧。於是我們可以把上面的解法改用 array 來寫:


class Solution {
 public:
  std::vector<std::vector<int>> threeSum(std::vector<int> &nums) {
    std::vector<std::vector<int>> output;

    for (const int num : nums) {
      if (num < 0) {
        m_negativeCount[OFFSET + num]++;
      } else {
        m_positiveCount[OFFSET + num]++;
      }
    }

    std::sort(nums.begin(), nums.end());
    // Erase the duplicate numbers
    nums.erase(std::unique(nums.begin(), nums.end()), nums.end());
    const int numsSize = nums.size();

    // 2 negative numbers + 1 positive number
    size_t index1 = 0;
    for (; index1 < numsSize; index1++) {
      const int num1 = nums[index1];
      if (num1 >= 0) break;
      for (size_t index2 = index1; index2 < numsSize; index2++) {
        if (index1 == index2 && m_negativeCount[OFFSET + num1] <= 1) continue;
        const int num2 = nums[index2];
        if (num2 >= 0) break;

        const int num3 = -(num1 + num2);
        if (num3 < 0 || num3 > MAX_VALUE) continue;
        if (m_positiveCount[OFFSET + num3] <= 0) continue;

        output.push_back({num1, num2, num3});
      }
    }

    // 1 negative number + 2 positive numbers
    for (; index1 < numsSize; index1++) {
      const int num1 = nums[index1];
      if (num1 < 0) continue;
      for (size_t index2 = index1; index2 < numsSize; index2++) {
        if (index1 == index2 && m_positiveCount[OFFSET + num1] <= 1) continue;
        const int num2 = nums[index2];

        const int num3 = -(num1 + num2);
        if (num3 < MIN_VALUE || num3 >= 0) continue;
        if (m_negativeCount[OFFSET + num3] <= 0) continue;

        output.push_back({num1, num2, num3});
      }
    }

    // Special case: 3 zeros
    if (m_positiveCount[OFFSET + 0] >= 3) {
      output.push_back({0, 0, 0});
    }

    return output;
  }

  static const int MIN_VALUE = -100000;
  static const int MAX_VALUE = 100000;
  static const int OFFSET = (-MIN_VALUE);
  static const int ARRAY_SIZE = MAX_VALUE - MIN_VALUE + 1;

  int m_positiveCount[ARRAY_SIZE] = {0};
  int m_negativeCount[ARRAY_SIZE] = {0};
};

雖然會比較醜,而且失去了 STL containers 的好處 (像是檢查範圍),負數的時候還要自己 offset 到 0 以上。但是我們既然要追求速度就必須付出一些代價。

但是結果速度就真的快很多:


幾乎已經超越其他 99% 的 code。

Q: 為何會比 Two Pointers 解法還要快?

這個算法一樣是 O(n^2) 和常見的解法 Two Pointers 也是 O(n^2)。而且我們也同樣有 sort + 移除重複的數字。我們隨便抓一個 Two Pointers 改良版丟上去都還是慢了 2 倍以上,為何這個解法比較快?

我能想到的主因就是因為這個解法 "找東西" 找的比較快。

Two pointers 解法遇到一個例子就會找很久,例如 num1 = -50, num2 = -50,我想要找 100 在不在,那我要讓 Two Pointers 難堪就是加入很多接近 100 卻不是 100 的數字: 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, ...

Two pointers 如果覺得目標應該在右邊就會慢慢往右邊走,覺得目標在左邊就會慢慢往左邊走。但是我們的這個解法可以直接一步知道 100 到底在不在。

你可能會覺得說可是我們這個解法從 2 個正數出發也是會碰到這些 95, 96, 97, ... 數字啊!那我針對 num1 = 50, num2 = 50 我也造一堆 -95, -96, -97, ... 同理讓 two pointers 難堪,如果 two pointers 從負數出發就躲不掉了。但我們這個解法至少從任何一邊下手都可以迴避掉這些數字。

再加上一開始的數學分析讓我們根本就不用跑一些組合,所以比較快。

Q: 沒有分負數/正數會比較慢嗎?

你可能會懷疑說是因為 hash table 用 array 做所以才比較快的吧!和數學沒有關係?

其實我一開始就用 array 卻沒有數學分析的版本,結果就慢了一點:


實作: Clean Code

改成用 array 之後程式碼就突然變了很亂,如果我在公司的話這樣寫鐵定 code review 不會過,所以我可能會這樣寫:


class Solution {
 public:
  std::vector<std::vector<int>> threeSum(std::vector<int> &nums) {
    countNumbers(nums);

    sortAndMakeUnique(nums /*may change*/);

    findOnePositive(nums);
    findOneNegative(nums);
    findOneZero();
    return m_output;
  }

  void countNumbers(const std::vector<int> &nums) {
    for (const int num : nums) {
      if (num < 0) {
        m_negativeCount[NEGATIVE_OFFSET + num]++;
      } else {
        m_nonNegativeCount[NON_NEGATIVE_OFFSET + num]++;
      }
    }
  }

  void sortAndMakeUnique(std::vector<int> &nums) {
    std::sort(nums.begin(), nums.end());
    nums.erase(std::unique(nums.begin(), nums.end()), nums.end());
  }

  void findOnePositive(const std::vector<int> &nums) {
    // num1(-) + num2(-) + num3(+) = 0
    const int numsSize = nums.size();
    for (size_t index1 = 0; index1 < numsSize; index1++) {
      const int num1 = nums[index1];
      if (num1 >= 0) break;
      for (size_t index2 = index1; index2 < numsSize; index2++) {
        if (index1 == index2 && m_negativeCount[NEGATIVE_OFFSET + num1] <= 1)
          continue;
        const int num2 = nums[index2];
        if (num2 >= 0) break;

        const int num3 = -(num1 + num2);
        if (num3 < 0 || num3 > MAX_VALUE) continue;
        if (m_nonNegativeCount[NON_NEGATIVE_OFFSET + num3] <= 0) continue;

        m_output.push_back({num1, num2, num3});
      }
    }
  }

  void findOneNegative(const std::vector<int> &nums) {
    // num1(+) + num2(+) + num3(-) = 0
    const int numsSize = nums.size();
    for (size_t index1 = 0; index1 < numsSize; index1++) {
      const int num1 = nums[index1];
      if (num1 < 0) continue;
      for (size_t index2 = index1; index2 < numsSize; index2++) {
        if (index1 == index2 &&
            m_nonNegativeCount[NON_NEGATIVE_OFFSET + num1] <= 1)
          continue;
        const int num2 = nums[index2];

        const int num3 = -(num1 + num2);
        if (num3 < MIN_VALUE || num3 >= 0) continue;
        if (m_negativeCount[NEGATIVE_OFFSET + num3] <= 0) continue;

        m_output.push_back({num1, num2, num3});
      }
    }
  }

  void findOneZero() {
    if (m_nonNegativeCount[NON_NEGATIVE_OFFSET + 0] >= 3) {
      m_output.push_back({0, 0, 0});
    }
  }

  static const int MIN_VALUE = -100000;
  static const int MAX_VALUE = 100000;
  static const int NEGATIVE_OFFSET = (-MIN_VALUE);
  static const int NON_NEGATIVE_OFFSET = 0;

  int m_negativeCount[-1 - MIN_VALUE + 1] = {0};
  int m_nonNegativeCount[MAX_VALUE - 0 + 1] = {0};
  std::vector<std::vector<int>> m_output;
};

而且我之前偷懶不想分正負數的 hash table 所以就乾脆大家都一樣大,乾淨版的就有考慮到要省記憶體,所以有特定分 NEGATIVE_OFFSETNON_NEGATIVE_OFFSET。名稱也改了一下,不然之前 positive 裡面卻有 0 就怪怪的。

結語

解 LeetCode 真的很好玩,要是我還沒進公司前來解這題大概是不會想到數學分析的方法的,光是寫 code 就夠難了還跟你玩數學?

可是這次超快的速度再次證明了數學的有用,不過還是要對 STL 和資料結構有些基礎認識 (加上一些聽說這樣那樣比較慢),兩者綜合在一起才會效果加倍。

還是鼓勵大家在動手寫程式前先想一下為何要這樣做,有沒有其他的辦法,不要網路上說要做 two pointers 就一定只能這樣做。

有興趣的話還可以看下一題 3Sum Closest 我們一樣可以用數學分析來加速。

留言

此網誌的熱門文章

[試算表] 追蹤台股 Google Spreadsheet (未實現損益/已實現損益)

[Side Project] 互動式教學神經網路反向傳播 Interactive Computational Graph

[插件] 在 Chrome 網頁做區分大小寫的搜尋