
Stanislav Issayenko, December 19, 2024

Bringing AI to the browser: SAM2 for interactive image segmentation

Introduction

In computer vision, image segmentation is crucial for object recognition, image editing, autonomous driving, and other common applications. The Segment Anything Model 2 (SAM2) pushes the boundaries of interactive image segmentation by allowing users to segment objects in images with minimal input. 

Traditionally, running such models required powerful servers or specialized hardware. However, with advancements in web technologies and machine learning libraries, it's now possible to run complex models like SAM2 directly in a browser.

In this article, we explore how to implement the SAM2 model in a web browser using ONNX Runtime Web (ort). We delve into the architecture of SAM2, how to load and run the model in a browser, and how to create an interactive user interface for real-time image segmentation.

Understanding the SAM2 architecture

Encoder-decoder framework

SAM2 uses an encoder-decoder architecture:

  • Encoder: Processes the input image to generate a high-dimensional embedding that captures essential features. 
  • Decoder: Takes the embedding and user-provided points (positive and negative) to generate segmentation masks.

This architecture enables interactive, highly accurate segmentation and lets users iteratively refine the result by adding more points.
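In practice, this split means the expensive encoder pass runs once per image, while the lightweight decoder pass can run on every click. Here is a minimal sketch of that flow, using the SAM2Encoder and SAM2Predictor classes implemented later in this article (the example point coordinates are arbitrary):

// Heavy step: run once per uploaded image
const embedding = await encoder.encode(image);

// Light step: rerun on every user interaction with the accumulated points
const results = await predictor.predict(embedding, [
  { x: 512, y: 384, type: 1 }, // a positive (foreground) point
]);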

Why run SAM2 in the browser?

Running SAM2 in the browser offers several key benefits:

  • Privacy: Images are processed locally, ensuring user data isn't sent to external servers. 
  • Accessibility: Users can access the segmentation tool without installing specialized software.
  • Interactivity: Real-time feedback enhances the user experience, allowing for quick iterations.

Implementing the encoder (encoder.js)

Initialization

The encoder loads the ONNX model and prepares it for inference:

const ENCODER_MODEL_URL = 'https://storage.googleapis.com/lb-artifacts-testing-public/sam2/sam2_hiera_tiny.encoder.ort';

class SAM2Encoder {
  constructor() {
    this.session = null;
  }

  async initialize() {
    this.session = await ort.InferenceSession.create(ENCODER_MODEL_URL);
    console.log('Encoder model loaded successfully');
  }
}
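By default, ort selects an execution provider automatically. If you use a build of ONNX Runtime Web that ships the WebGPU backend, you can request it explicitly and fall back to WebAssembly. A sketch of what that could look like (session options beyond the model URL are not part of the original code):

this.session = await ort.InferenceSession.create(ENCODER_MODEL_URL, {
  executionProviders: ['webgpu', 'wasm'], // try WebGPU first, fall back to WASM
});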

Image preprocessing

Before passing the image to the encoder, it must be resized and normalized:

  • Resize: Adjust the image to 1024x1024 pixels.
  • Normalize: Scale pixel values to the [-1, 1] range.
imageDataToTensor(image) {
  const canvas = document.createElement('canvas');
  const ctx = canvas.getContext('2d');
  canvas.width = canvas.height = 1024;

  ctx.drawImage(image, 0, 0, 1024, 1024);
  const imageData = ctx.getImageData(0, 0, 1024, 1024).data;
  const inputArray = new Float32Array(3 * 1024 * 1024);

  for (let i = 0; i < 1024 * 1024; i++) {
    inputArray[i] = (imageData[i * 4] / 255.0) * 2 - 1; // R channel
    inputArray[i + 1024 * 1024] = (imageData[i * 4 + 1] / 255.0) * 2 - 1; // G channel
    inputArray[i + 2 * 1024 * 1024] = (imageData[i * 4 + 2] / 255.0) * 2 - 1; // B channel
  }

  return new ort.Tensor('float32', inputArray, [1, 3, 1024, 1024]);
}

Encoding process

The encode method runs the model and generates the image embedding:

async encode(image) {
  const tensor = this.imageDataToTensor(image);
  const feeds = { image: tensor };
  const results = await this.session.run(feeds);
  this.lastEmbeddings = results.image_embed;
  return this.lastEmbeddings;
}
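Putting initialization and encoding together, a typical call sequence looks like this (a sketch, assuming img is an already-loaded HTMLImageElement):

const encoder = new SAM2Encoder();
await encoder.initialize();

// The embedding is computed once and reused by the decoder for every click
const embedding = await encoder.encode(img);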

Implementing the decoder (decoder.js)

Initialization

Similar to the encoder, the decoder loads its ONNX model:

const DECODER_MODEL_URL = 'https://storage.googleapis.com/lb-artifacts-testing-public/sam2/sam2_hiera_tiny.decoder.onnx';

class SAM2Predictor {
  constructor() {
    this.session = null;
  }

  async initialize() {
    this.session = await ort.InferenceSession.create(DECODER_MODEL_URL);
    console.log('Decoder model loaded successfully');
  }
}

Preparing inputs

The decoder requires several inputs, including the image embedding and user interaction points:

  • Point coordinates (point_coords): The (x, y) positions of user clicks.
  • Point labels (point_labels): Indicates positive (foreground) or negative (background) points.
  • Mask input (mask_input): An initial mask, set to zeros if not using a previous mask.
  • Other tensors: Additional required inputs like has_mask_input, high_res_feats_0, and high_res_feats_1.
prepareInputs(embedding, points) {
  const numLabels = 1;
  const numPoints = points.length;
  const pointCoordsData = [];
  const pointLabelsData = [];

  // Flatten the user clicks into coordinate and label arrays
  for (const point of points) {
    pointCoordsData.push([point.x, point.y]);
    pointLabelsData.push(point.type);
  }

  return {
    image_embed: embedding,
    point_coords: new ort.Tensor('float32', Float32Array.from(pointCoordsData.flat()), [numLabels, numPoints, 2]),
    point_labels: new ort.Tensor('float32', Float32Array.from(pointLabelsData), [numLabels, numPoints]),
    // No previous mask: pass a zero-filled mask and set has_mask_input to 0
    mask_input: new ort.Tensor('float32', new Float32Array(numLabels * 1 * 256 * 256), [numLabels, 1, 256, 256]),
    has_mask_input: new ort.Tensor('float32', new Float32Array([0.0]), [numLabels]),
    // Zero-filled high-resolution feature tensors for this simplified example
    high_res_feats_0: new ort.Tensor('float32', new Float32Array(1 * 32 * 256 * 256), [1, 32, 256, 256]),
    high_res_feats_1: new ort.Tensor('float32', new Float32Array(1 * 64 * 128 * 128), [1, 64, 128, 128]),
  };
}

Prediction process

The predict method runs the decoder model to generate the segmentation mask:

async predict(embedding, inputPoints) {
  const inputs = this.prepareInputs(embedding, inputPoints);
  const results = await this.session.run(inputs);
  return results;
}

Building the user interface (app.js)

Setting up the canvas

We use two HTML canvas elements:

  • Source canvas (sourceCanvas): Displays the uploaded image and interaction points.
  • Mask canvas (maskCanvas): Overlays the segmentation mask.
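
One way to set this up is to stack the mask canvas directly on top of the source canvas so the segmentation renders as an overlay. A minimal sketch (the element IDs are assumptions, not taken from the original markup):

const sourceCanvas = document.getElementById('sourceCanvas');
const maskCanvas = document.getElementById('maskCanvas');
const sourceCtx = sourceCanvas.getContext('2d');

// Overlay the mask canvas and let clicks pass through to the source canvas
maskCanvas.style.position = 'absolute';
maskCanvas.style.left = sourceCanvas.offsetLeft + 'px';
maskCanvas.style.top = sourceCanvas.offsetTop + 'px';
maskCanvas.style.pointerEvents = 'none';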

Handling user interactions

Image upload

When a user uploads an image, two things happen:

  • The image is drawn on the sourceCanvas.
  • The encoder generates embeddings from the image.
imageInput.addEventListener('change', async (e) => {
  // Load the selected file into an Image element
  const img = new Image();
  img.src = URL.createObjectURL(e.target.files[0]);
  await img.decode();

  // Draw the image on the source canvas, then encode it
  sourceCtx.drawImage(img, 0, 0, sourceCanvas.width, sourceCanvas.height);
  embedding = await encoder.encode(img);
});

Adding interaction points

Users can click on the image to add positive or negative points:

  • Positive points: Indicate areas to include in the segmentation.
  • Negative points: Indicate areas to exclude.
sourceCanvas.addEventListener('click', async (e) => {
  // Convert the click position to canvas coordinates
  const rect = sourceCanvas.getBoundingClientRect();
  const x = e.clientX - rect.left;
  const y = e.clientY - rect.top;
  const point = { x: x, y: y, type: isNegative ? 0 : 1 };
  points.push(point);

  // Draw point on canvas
  drawPoint(sourceCtx, point);

  // Run prediction
  const results = await predictor.predict(embedding, points);

  // Draw mask
  drawMaskOnCanvas(maskCanvas, results['masks'], imageWidth, imageHeight);
});

Toggling point types

A button allows users to switch between positive and negative points:

negativeBtn.addEventListener('click', () => {
  isNegative = !isNegative;
  negativeBtn.textContent = `Negative Points: ${isNegative ? 'ON' : 'OFF'}`;
});

Drawing points and masks

Drawing points

Points are drawn on the sourceCanvas to provide visual feedback:

function drawPoint(ctx, point) {
  ctx.fillStyle = point.type === 1 ? 'green' : 'red';
  ctx.beginPath();
  ctx.arc(point.x, point.y, 5, 0, 2 * Math.PI);
  ctx.fill();
}

Drawing masks

Masks are drawn on the maskCanvas to display the segmentation result:

function drawMaskOnCanvas(maskCanvas, maskData, imageWidth, imageHeight) {
  // Process the mask tensor and draw it over the image
}
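The stub above is where the mask tensor becomes pixels. A minimal sketch of one way to fill it in, assuming the decoder returns a single low-resolution mask of logits (for example, shape [1, 1, 256, 256]) that we threshold at zero and scale up to the displayed image size:

function drawMaskOnCanvas(maskCanvas, maskTensor, imageWidth, imageHeight) {
  const [, , maskHeight, maskWidth] = maskTensor.dims;
  const maskValues = maskTensor.data;

  // Render the low-resolution mask into an offscreen canvas
  const offscreen = document.createElement('canvas');
  offscreen.width = maskWidth;
  offscreen.height = maskHeight;
  const offCtx = offscreen.getContext('2d');
  const maskImage = offCtx.createImageData(maskWidth, maskHeight);

  for (let i = 0; i < maskWidth * maskHeight; i++) {
    if (maskValues[i] > 0) {             // logit > 0 means "inside the mask"
      maskImage.data[i * 4] = 0;         // R
      maskImage.data[i * 4 + 1] = 114;   // G
      maskImage.data[i * 4 + 2] = 189;   // B
      maskImage.data[i * 4 + 3] = 128;   // semi-transparent alpha
    }
  }
  offCtx.putImageData(maskImage, 0, 0);

  // Scale the mask up to the displayed image size
  const ctx = maskCanvas.getContext('2d');
  ctx.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
  ctx.drawImage(offscreen, 0, 0, maskWidth, maskHeight, 0, 0, imageWidth, imageHeight);
}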

Running the application

Step-by-step guide

  1. Set up the server: Start your local server to host the model files.
  2. Include scripts: In your HTML file, include ort and your JavaScript modules.
  3. Open the application: Access the HTML file through your browser.
  4. Upload an image: Use the file input to select an image.
  5. Interact with the image: Click on the image to add points and see the segmentation mask update in real-time.
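
Tying these steps together, a rough bootstrap might look like this (a sketch; it assumes the SAM2Encoder and SAM2Predictor classes are in scope once the scripts are included):

const encoder = new SAM2Encoder();
const predictor = new SAM2Predictor();

async function init() {
  // Load both ONNX models before wiring up the UI event handlers
  await Promise.all([encoder.initialize(), predictor.initialize()]);
  console.log('SAM2 models ready');
}

init();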

Demonstration

Check out and play around with an interactive demo here.

Conclusion

In this article, we've demonstrated how to run the Segment Anything Model 2 (SAM2) directly in the web browser. By leveraging ONNX Runtime Web and a thoughtful implementation of the encoder and decoder, we've created an interactive image segmentation tool that runs entirely on the client side.

This approach opens doors to privacy-preserving applications and makes advanced machine learning models more accessible. As web technologies continue to evolve, we can expect even more sophisticated models to run efficiently in the browser.

At Labelbox, we’ve adopted a hybrid strategy, running SAM’s encoder on the server while executing its decoder right in the browser. This approach ensures real-time image segmentation, preserves user privacy, and expands access to cutting-edge machine learning. As web technologies advance, we look forward to delivering even more powerful models efficiently and securely in the browser.

Check out these additional resources: