[Side Project] 互動式教學神經網路反向傳播 Interactive Computational Graph
成品
為何要做這套工具
以前在學 machine learning 的時候覺得最難理解的部分就是 backpropagation,也就是神經網路到底是怎麼用數學來做訓練的。
我們知道神經網路就是長這樣子:
Source: Wikipedia |
但實際上每個 neuron (圓圈圈) 裡面有許多基本的數學運算,像是 sum 還有 activation function:
Source: Neural Networks and Machine Learning |
當我們把所有東西綜合在一起,腦袋就會打結,到底 chain rule 是怎麼在這麼複雜的神經網路之間套用的? 就算了解基礎也很難相信實際上到底怎麼運作的。我又沒有要真的使用的話,網路上的教學就覺得看看就好
Source: A Step by Step Backpropagation Example |
但決心要了解 backpropagation 細節之後,不少教學都提到了 computational graph,從基礎的數學運算當作 nodes 來去理解似乎是比較簡單的方式。
後來找到了這篇文章 Calculus on Computational Graphs: Backpropagation 把 chain rule 講解的簡單至極,於是就在想,為何我不做一個網頁,讓大家可以自由地組合想要的 graph,然後看一下 chain rule 是怎麼流通的。
失敗作經驗
在做這次 side project 前,其實早在一年前我就開始作另一個 side project 是想展現神經網路的訓練過程。名稱叫做 explain-derivatives,是被 TensorFlow 的一個網頁 A Neural Network Playground 啟發的。
但由於一開始欠缺規劃,想到什麼就加什麼功能進去,造成架構一直修改,最後由於程式碼架構太過複雜就放棄了。
這個 project 是用 Cytoscape.js 為底,加上 TypeScript,可以打造任何形狀的 fully connected neural network,目的是 step-by-step 觀看 neural network 的訓練過程,配上數學解說。
後來到了新公司,看了 Cleanr Architecture 的書,還有在 Discord 上面與其他人打造網頁的經驗後,這次的 project 一開始就決定要嚴格按照軟體的流程來走。
當真的 Project 在做
有了失敗作的經驗,這次 side project 除了在 GitHub project 的 specs 資料夾先寫各種 specs 後,還用 Figma 製作了網頁的 prototypes:
雖然 side project 就只有我在做,但我想像成這是目前公司實際上的一個 project。於是就把自己當成是不同 teams 的不同人,開不同的 GitHub issues 與設定 milestones,GitHub issues 把工作切成很多小塊,milestones 則是把時間切成很多不同衝刺的目標。
後來驗證了這個方法的確管用,一開始程式碼還少的時候覺得要寫任何東西都很簡單,但到中期大部分核心功能都完成,開始做與 UI 比較相關的功能時就覺得程式碼怎麼突然變得這麼巨大,只要隔一周沒碰再來看就真的覺得之前的程式碼是別人寫的。
不過好在每個 GitHub issue 都專心在做特定的功能,所以就算在一堆程式碼中加入新功能也不容易喪失方向,自己也會覺得目標很快就可以達成了。
Clean Architecture
這次 side project 也有按照 Uncle Bob 的 clean architecture 架構來好好規劃,這個 project 中的每個 components 在 clean architecture 的角色會是:
- Core (*.ts,純 TypeScript): 核心演算法,包含 derivatives 計算還有解釋的部分,和 UI 完全無關
- ReactFlow & Features: 依靠核心演算法的零件,像是這次用的第三方 library ReactFlow 還有一些 use cases 而衍生出來的 features
- Components (React components): 依靠其他部分的 UI 零件,不會直接使用到 core 的零件
我覺得最有趣的概念就是,一開始我先做核心演算法,一旦做完了這個 project 的成功率就基本上保證一半,因為至少我知道概念行得通。
再往上疊只不過是讓使用者可以透過 UI 看的到演算法的結果。
而且越外圍的東西越可以被取代,如果哪一天 ReactFlow graph 的第三方軟體沒更新了,我也可以考慮替換成其他的軟體,而不會影響到核心演算法。或是哪天考慮不要用 ReactJS 做了? 也可以換成其他的軟體。
插件系統
在規劃的時候就想好要讓使用者可以自由地編輯或新增各種可以微分的 operators,這也代表核心演算法的架構需要好好地被設計。
於是就設計出類似下面這樣的架構,最重要的部分是 Operation 和 Graph / GraphNode 是完全沒有關係的,也就是說當核心演算法的 Graph 想要計算某個 node 的 derivative 時,他完全不知道operation 的細節。
Source: GitHub issue #4 |
我們也可以用 clean architecture 來思考,Operation 是最內層 Core 中比較外環的,而其他 Graph / GraphNode 則是 Core 中最內層的部分。
反過來想,如果我們讓 Graph 知道 Operation 的存在,乍看似乎好像也不會怎麼樣。不過如果我們把所有細節都混在一起,在不斷開發過程裡面只要一不小心就會讓兩個獨立的概念綁在一起 (coupled),要一直保持警覺地把兩個東西分開是很困難的,不如從架構上一開始就限制他們的地位。
Lifting State Up
React states 到底該放在哪? 一開始我有兩個選擇:
- 集中管理放在最上層的 React component
- 分別方在不同的 React components 之中
分別放在不同的 React components 一開始最吸引我,因為那代表我不用看到一個檔案裏面有一大堆 useState。
但是如果分開放的話資料流在不同的 components 之間就要往上又往下,就像下面這張初始的架構中不同的 components 之間要溝通的話,介面會很麻煩。但我也不想使用 global state,不然測試會比較麻煩。
Source: GitHub issue #8 |
於是就還是決定把 states 都集中在目前最高層的 GraphContainer 中 (和上面這張圖不太一樣)。
雖然 GraphContainer 中有 15 個 useStates,但大部分處理的程式碼都被 delegate 到其他的 components 處理。GraphContainer 就像是 switch 一樣的角色。
也由於 states 都是從上往下傳遞,unit tests 的部分也非常好寫,只要 properties 更換一下就好了。
測試無法被 Mock 的第三方元件 ReactFlow
以前 side projects 沒有很嚴格地做 unit tests,這次就想來試試如果所有的 components 都有 unit tests 會發生什麼事。
雖然都要寫 tests 很麻煩,但結果確實還抓到不少以為新加一個小功能不會有什麼 bugs,卻被舊的 unit tests 給抓到。
但最令人頭痛的是第三方 library 根本無法測試的問題,自己的元件要為了測試可以進行各種改寫,但第三方軟體要要求他實在很難。
這次最難測試的就是網頁中 graph 的部分,是透過 ReactFlow 去 render 的。但在測試環境下光是 render ReactFlow 就已經有難度了,而且要去操作 graph 中 nodes 的連線更是難上加難。如果把 ReactFlow 整個 mock 掉也不是我希望的,因為他是網頁中非常重要的一環,就算有一些小 bug 也會大大地影響到使用者體驗。
於是我就採取一種旁敲側擊的測試手法,與其把他 mock 掉,不如加上一個測試用的 helper (ReactFlowGraphTestHelper),和 ReactFlow 一起在測試環境中 render 出來。架構上會是這樣子:
- GraphContainer (nodes & edges state)
- <ReactFlowGraphTestHelper nodes={nodes} edges={edges} onConnect={handleConnect} />
- <ReactFlowGraph nodes={nodes} edges={edges} onConnect={handleConnect} />
我就可以在 unit test 中去操控 ReactFlowGraphTestHelper 內部很好被操控的 TextField 和 Buttons,觸發 onConnect 讓 GraphContainer 以為是 ReactFlowGraph 發出來的。
實際上要測試的部分是 GraphContainer 接收到 handleConnect 後怎麼去更新 nodes & edges,以及他把更新的 nodes & edges 餵到第三方 ReactFlowGraph 後會不會有什麼問題。
缺點就是 onConnect 發出來的資訊要跟 ReactFlow 實際發出來的一樣,所以未來還是有可能會有 break 的風險,但至少其他的部分可以測的到。
擴展性
實務上 ML 在訓練雖然也是用 computational graph,例如 PyTorch 是用 autograd 的核心套件在做 differentiation,但通常資料不會只有一個 scalar 的數值,常常會至少是 2D 以上的維度 (多一個維度例如是多個 samples),這樣才可以用 GPU 平行運算。
為了讓 core 元件也保有這個彈性,所以你可能會發現數值的部分我都是用 string 而不是 number 在傳,目的就是為了讓更高維度的資料也可以用 string encode 起來,反正 operations 可以自由地被使用者修改,要怎麼 decode 以及 differentiate 是使用者的自由。
唯一使用者必須要改 code 的地方是 derivatives 要如何相加,目前只有這部分是預設 string 中的資料都是 scalar,如果是 2D 以上的維度就是要一個一個元素相加。
結論
透過這次的 side project,我可以動手實作 clean architecture 並了解到他真正的優點。
我工作內容也不是 ReactJS 方面,大部分都是 ChatGPT 幫助的,第一次感覺到好像在與 AI 一起協作,把自己當成不同團隊的人,每個 GitHub issue 都可能要從不同的角度面對問題,是到目前為止做過最有挑戰性的一次 side project。
留言
發佈留言