U²Net TensorFlow.js

31 December 2020

3 min read

Table of Contents

1. Background

2. Machine Learning in the Browser

3. Demo

Background

U²Net is a fascinating neural network architecture for salient object detection, which is the detection and segmentation of the most visually prominent objects in an image. In other words, the network tries to extract the interesting regions from an image, which enables applications like background removal and replacement effects.


However, a downside to models like U²Net is that they usually require a computing environment equipped with large dependencies like PyTorch or TensorFlow. On websites, these models are therefore typically run server-side, with the output sent back to the front-end. So the thought occurred to me: was there a way to run them fully browser-side, without any server dependencies?

Machine Learning in the Browser

The answer was yes: TensorFlow.js is a mature library that can run TensorFlow-based ML models in the browser. However, U²Net was implemented in PyTorch, so the available pre-trained weights were also in PyTorch format (.pth). I searched for ways to bridge this gap and found the holy grail: ONNX, which, according to its homepage, is "an open format built to represent machine learning models". In the end, I arrived at the following procedure, sketched in code after the list:

  1. Convert PyTorch weights (.pth) to ONNX weights (.onnx)
  2. Convert ONNX weights (.onnx) to TensorFlow's SavedModel format
  3. Convert TensorFlow's SavedModel to TensorFlow.js weights
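
Roughly, the pipeline looks like the sketch below. The calls come from each tool's documented interface rather than my exact scripts, and the file names (u2net.onnx, web_model/, etc.) are illustrative:

```ts
// Conversion pipeline (run offline, not in the browser):
//
//   1. PyTorch -> ONNX (Python):
//        torch.onnx.export(net, dummy_input, "u2net.onnx", opset_version=11)
//
//   2. ONNX -> TensorFlow SavedModel (Python, using the onnx-tf package):
//        from onnx_tf.backend import prepare
//        prepare(onnx.load("u2net.onnx")).export_graph("u2net_saved_model")
//
//   3. SavedModel -> TensorFlow.js (shell, using the tensorflowjs package):
//        tensorflowjs_converter --input_format=tf_saved_model \
//            u2net_saved_model web_model

// Smoke test in the browser: the converted graph model should now load.
import * as tf from '@tensorflow/tfjs';

tf.loadGraphModel('web_model/model.json').then((model) => {
  console.log('inputs:', model.inputs, 'outputs:', model.outputs);
});
```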

Demo

With some experimentation, the above procedure bore fruit: I had a working TensorFlow.js model that could be loaded in the browser. I built a demo which pre-processes input images, feeds them to the model, and displays the result.
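
As a rough sketch, the inference step looks something like this. The 320×320 input size matches the original U²Net implementation's default, but the model path, NHWC layout, and output handling are assumptions, not the demo's exact code (the demo also resizes with OpenCV.js, described next; tf.image.resizeBilinear stands in here to keep the sketch self-contained):

```ts
import * as tf from '@tensorflow/tfjs';

// Run one image through the converted model. Depending on how the graph was
// exported, it may expect NCHW instead of NHWC, and U²Net's several side
// outputs may come back as an array -- the fused (first) output is used here.
async function runU2Net(image: ImageData): Promise<tf.Tensor> {
  const model = await tf.loadGraphModel('web_model/model.json');

  const input = tf.tidy(() => {
    const pixels = tf.browser.fromPixels(image).toFloat().div(255); // HxWx3 in [0, 1]
    const resized = tf.image.resizeBilinear(pixels as tf.Tensor3D, [320, 320]);
    return resized.expandDims(0); // add batch dimension -> 1x320x320x3
  });

  const output = model.predict(input);
  const mask = (Array.isArray(output) ? output[0] : output) as tf.Tensor;
  input.dispose();
  return mask; // 1x320x320x1 saliency probabilities
}
```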


OpenCV.js was used to resize the image to a suitable input dimension for the model. The model's output was a grayscale mask with varying opacity values, so I used OpenCV.js to round each pixel to either solid black or solid white, producing what is known as a trimmed mask. Finally, the trimmed mask was multiplied with the original image so that only the regions under the white pixels were retained.
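
A sketch of that post-processing with OpenCV.js follows. The canvas elements and function name are hypothetical, and a bitwise AND against the 0/255 mask stands in for the per-pixel multiplication described above (the two are equivalent for a binary mask):

```ts
// OpenCV.js is loaded as a global script, hence the untyped `cv` handle.
declare const cv: any;

function applyTrimmedMask(imageCanvas: HTMLCanvasElement, maskCanvas: HTMLCanvasElement): void {
  const image = cv.imread(imageCanvas); // original photo, RGBA
  const mask = cv.imread(maskCanvas);   // model's grayscale output
  cv.cvtColor(mask, mask, cv.COLOR_RGBA2GRAY);

  // Scale the mask back up to the original image's dimensions.
  cv.resize(mask, mask, new cv.Size(image.cols, image.rows), 0, 0, cv.INTER_LINEAR);

  // Trim: round every pixel to solid black (0) or solid white (255).
  cv.threshold(mask, mask, 127, 255, cv.THRESH_BINARY);

  // Retain the original pixels only where the trimmed mask is white.
  const result = new cv.Mat();
  cv.bitwise_and(image, image, result, mask);

  cv.imshow(imageCanvas, result);
  [image, mask, result].forEach((m: any) => m.delete());
}
```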


To avoid stalling the main thread, I ran the model in a Web Worker. The model's segmentation is quite rough around the edges compared to the original model, but it was a satisfying attempt at running this entirely in the browser.
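
The wiring is roughly the following; the file names, element ids, and the maskToImageData helper are illustrative, not the demo's actual code:

```ts
// main.ts -- keep the UI thread free; the worker does the heavy lifting.
const worker = new Worker('segment.worker.js');

const source = document.getElementById('input') as HTMLCanvasElement;
const output = document.getElementById('output') as HTMLCanvasElement;

worker.onmessage = (e: MessageEvent<ImageData>) => {
  // Painting the finished result is the only work left on the main thread.
  output.getContext('2d')!.putImageData(e.data, 0, 0);
};

// ImageData is structured-cloned across the thread boundary.
worker.postMessage(
  source.getContext('2d')!.getImageData(0, 0, source.width, source.height)
);

// segment.worker.ts -- inference runs here, off the main thread.
// `runU2Net` is the inference sketch from earlier; `maskToImageData` is a
// hypothetical helper that converts the mask tensor back into ImageData.
const ctx = self as unknown as Worker;

ctx.onmessage = async (e: MessageEvent<ImageData>) => {
  const mask = await runU2Net(e.data);
  ctx.postMessage(await maskToImageData(mask));
};
```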


Here's a video of the demo:

And finally a meme: