23 changed files with 766 additions and 0 deletions
After Width: | Height: | Size: 1.9 MiB |
@ -0,0 +1,95 @@
@@ -0,0 +1,95 @@
|
||||
## Segment Anything Simple Web demo |
||||
|
||||
This **front-end only** demo shows how to load a fixed image and `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using Web Assembly with multithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128.
||||
|
||||
<img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/> |
||||
|
||||
## Run the app |
||||
|
||||
``` |
||||
yarn && yarn start |
||||
``` |
||||
|
||||
Navigate to [`http://localhost:8081/`](http://localhost:8081/) |
||||
|
||||
Move your cursor around to see the mask prediction update in real time. |
||||
|
||||
## Change the image, embedding and ONNX model |
||||
|
||||
In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) upload the image of your choice and generate and save the corresponding embedding.
||||
|
||||
Initialize the predictor |
||||
|
||||
```python |
||||
checkpoint = "sam_vit_h_4b8939.pth" |
||||
model_type = "vit_h" |
||||
sam = sam_model_registry[model_type](checkpoint=checkpoint) |
||||
sam.to(device='cuda') |
||||
predictor = SamPredictor(sam) |
||||
``` |
||||
|
||||
Set the new image and export the embedding |
||||
|
||||
``` |
||||
image = cv2.imread('src/assets/dogs.jpg') |
||||
predictor.set_image(image) |
||||
image_embedding = predictor.get_image_embedding().cpu().numpy() |
||||
np.save("dogs_embedding.npy", image_embedding) |
||||
``` |
||||
|
||||
Save the new image and embedding in `/assets/data` and update the following paths to the files at the top of `App.tsx`:
||||
|
||||
```js
||||
const IMAGE_PATH = "/assets/data/dogs.jpg"; |
||||
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy"; |
||||
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx"; |
||||
``` |
||||
|
||||
Optionally you can also export the ONNX model. Currently the example ONNX model from the notebook is saved at `/model/sam_onnx_quantized_example.onnx`. |
||||
|
||||
**NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.** |
||||
|
||||
## ONNX multithreading with SharedArrayBuffer |
||||
|
||||
To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details) |
||||
|
||||
The headers below are set in `configs/webpack/dev.js`: |
||||
|
||||
```js |
||||
headers: { |
||||
"Cross-Origin-Opener-Policy": "same-origin", |
||||
"Cross-Origin-Embedder-Policy": "credentialless", |
||||
} |
||||
``` |
||||
|
||||
## Structure of the app |
||||
|
||||
**`App.tsx`** |
||||
|
||||
- Initializes ONNX model |
||||
- Loads image embedding and image |
||||
- Runs the ONNX model based on input prompts |
||||
|
||||
**`Stage.tsx`** |
||||
|
||||
- Handles mouse move interaction to update the ONNX model prompt |
||||
|
||||
**`Tool.tsx`** |
||||
|
||||
- Renders the image and the mask prediction |
||||
|
||||
**`helpers/maskUtils.tsx`** |
||||
|
||||
- Conversion of ONNX model output from array to an HTMLImageElement |
||||
|
||||
**`helpers/onnxModelAPI.tsx`** |
||||
|
||||
- Formats the inputs for the ONNX model |
||||
|
||||
**`helpers/scaleHelper.tsx`** |
||||
|
||||
- Handles image scaling logic for SAM (longest size 1024) |
||||
|
||||
**`hooks/`** |
||||
|
||||
- Handle shared state for the app |
@ -0,0 +1,78 @@
@@ -0,0 +1,78 @@
|
||||
const { resolve } = require("path");
const HtmlWebpackPlugin = require("html-webpack-plugin");
const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin");
const CopyPlugin = require("copy-webpack-plugin");
const webpack = require("webpack");

// Shared webpack configuration. The dev and prod configs merge this
// object (via webpack-merge) and add mode/devServer/output specifics.
module.exports = {
  entry: "./src/index.tsx",
  resolve: {
    extensions: [".js", ".jsx", ".ts", ".tsx"],
  },
  output: {
    path: resolve(__dirname, "dist"),
  },
  module: {
    rules: [
      // Let .mjs files inside node_modules resolve without fully
      // specified import paths (required by some ESM-only packages).
      {
        test: /\.mjs$/,
        include: /node_modules/,
        type: "javascript/auto",
        resolve: {
          fullySpecified: false,
        },
      },
      // Compile all JS/TS (and JSX/TSX) sources with ts-loader.
      {
        test: [/\.jsx?$/, /\.tsx?$/],
        use: ["ts-loader"],
        exclude: /node_modules/,
      },
      // Plain CSS: inject via style-loader after resolving imports.
      {
        test: /\.css$/,
        use: ["style-loader", "css-loader"],
      },
      // Sass/SCSS: additionally run postcss (see postcss.config.js).
      {
        test: /\.(scss|sass)$/,
        use: ["style-loader", "css-loader", "postcss-loader"],
      },
      // Images: emit content-hashed files and run lossless optimization.
      {
        test: /\.(jpe?g|png|gif|svg)$/i,
        use: [
          "file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]",
          "image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false",
        ],
      },
      // Fonts: inline via url-loader.
      {
        test: /\.(woff|woff2|ttf)$/,
        use: {
          loader: "url-loader",
        },
      },
    ],
  },
  plugins: [
    // Copy static artifacts into the bundle output:
    // - the onnxruntime-web WASM binaries (needed at runtime),
    // - the ONNX model directory,
    // - the image/embedding assets.
    new CopyPlugin({
      patterns: [
        {
          from: "node_modules/onnxruntime-web/dist/*.wasm",
          to: "[name][ext]",
        },
        {
          from: "model",
          to: "model",
        },
        {
          from: "src/assets",
          to: "assets",
        },
      ],
    }),
    new HtmlWebpackPlugin({
      template: "./src/assets/index.html",
    }),
    new FriendlyErrorsWebpackPlugin(),
    // Provide a browser shim for the Node `process` global.
    new webpack.ProvidePlugin({
      process: "process/browser",
    }),
  ],
};
@ -0,0 +1,19 @@
@@ -0,0 +1,19 @@
|
||||
// development config
const { merge } = require("webpack-merge");
const commonConfig = require("./common");

module.exports = merge(commonConfig, {
  mode: "development",
  devServer: {
    hot: true, // enable HMR on the server
    open: true,
    // These headers enable the cross origin isolation state
    // needed to enable use of SharedArrayBuffer for ONNX
    // multithreading.
    headers: {
      "Cross-Origin-Opener-Policy": "same-origin",
      "Cross-Origin-Embedder-Policy": "credentialless",
    },
  },
  // Fast rebuilds with usable stack traces during development.
  devtool: "cheap-module-source-map",
});
@ -0,0 +1,16 @@
@@ -0,0 +1,16 @@
|
||||
// production config
const { merge } = require("webpack-merge");
const { resolve } = require("path");
const Dotenv = require("dotenv-webpack");
const commonConfig = require("./common");

module.exports = merge(commonConfig, {
  mode: "production",
  output: {
    // Content-hashed bundle name for long-term caching.
    filename: "js/bundle.[contenthash].min.js",
    path: resolve(__dirname, "../../dist"),
    publicPath: "/",
  },
  // Full source maps for production debugging.
  devtool: "source-map",
  // Inject variables from a .env file at build time.
  plugins: [new Dotenv()],
});
@ -0,0 +1,64 @@
@@ -0,0 +1,64 @@
|
||||
{ |
||||
"name": "se-demo", |
||||
"version": "0.1.0", |
||||
"license": "MIT", |
||||
"scripts": { |
||||
"build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js && cp -R dataset dist", |
||||
"clean-dist": "rimraf dist/*", |
||||
"lint": "eslint './src/**/*.{js,ts,tsx}' --quiet", |
||||
"start": "yarn run start-dev", |
||||
"test": "yarn run start-model-test", |
||||
"start-dev": "webpack serve --config=configs/webpack/dev.js" |
||||
}, |
||||
"devDependencies": { |
||||
"@babel/core": "^7.18.13", |
||||
"@babel/preset-env": "^7.18.10", |
||||
"@babel/preset-react": "^7.18.6", |
||||
"@babel/preset-typescript": "^7.18.6", |
||||
"@pmmmwh/react-refresh-webpack-plugin": "^0.5.7", |
||||
"@testing-library/react": "^13.3.0", |
||||
"@types/node": "^18.7.13", |
||||
"@types/react": "^18.0.17", |
||||
"@types/react-dom": "^18.0.6", |
||||
"@types/underscore": "^1.11.4", |
||||
"@typescript-eslint/eslint-plugin": "^5.35.1", |
||||
"@typescript-eslint/parser": "^5.35.1", |
||||
"babel-loader": "^8.2.5", |
||||
"copy-webpack-plugin": "^11.0.0", |
||||
"css-loader": "^6.7.1", |
||||
"dotenv": "^16.0.2", |
||||
"dotenv-webpack": "^8.0.1", |
||||
"eslint": "^8.22.0", |
||||
"eslint-plugin-react": "^7.31.0", |
||||
"file-loader": "^6.2.0", |
||||
"fork-ts-checker-webpack-plugin": "^7.2.13", |
||||
"friendly-errors-webpack-plugin": "^1.7.0", |
||||
"html-webpack-plugin": "^5.5.0", |
||||
"image-webpack-loader": "^8.1.0", |
||||
"postcss-loader": "^7.0.1", |
||||
"postcss-preset-env": "^7.8.0", |
||||
"process": "^0.11.10", |
||||
"rimraf": "^3.0.2", |
||||
"sass": "^1.54.5", |
||||
"sass-loader": "^13.0.2", |
||||
"style-loader": "^3.3.1", |
||||
"tailwindcss": "^3.1.8", |
||||
"ts-loader": "^9.3.1", |
||||
"typescript": "^4.8.2", |
||||
"webpack": "^5.74.0", |
||||
"webpack-cli": "^4.10.0", |
||||
"webpack-dev-server": "^4.10.0", |
||||
"webpack-dotenv-plugin": "^2.1.0", |
||||
"webpack-merge": "^5.8.0" |
||||
}, |
||||
"dependencies": { |
||||
"konva": "^8.3.12", |
||||
"npyjs": "^0.4.0", |
||||
"onnxruntime-web": "^1.14.0", |
||||
"react": "^18.2.0", |
||||
"react-dom": "^18.2.0", |
||||
"react-konva": "^18.2.1", |
||||
"underscore": "^1.13.6", |
||||
"react-refresh": "^0.14.0" |
||||
} |
||||
} |
@ -0,0 +1,4 @@
@@ -0,0 +1,4 @@
|
||||
const tailwindcss = require("tailwindcss"); |
||||
module.exports = { |
||||
plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss], |
||||
}; |
@ -0,0 +1,124 @@
@@ -0,0 +1,124 @@
|
||||
import { InferenceSession, Tensor } from "onnxruntime-web"; |
||||
import React, { useContext, useEffect, useState } from "react"; |
||||
import "./assets/scss/App.scss"; |
||||
import { handleImageScale } from "./components/helpers/scaleHelper"; |
||||
import { modelScaleProps } from "./components/helpers/Interfaces"; |
||||
import { onnxMaskToImage } from "./components/helpers/maskUtils"; |
||||
import { modelData } from "./components/helpers/onnxModelAPI"; |
||||
import Stage from "./components/Stage"; |
||||
import AppContext from "./components/hooks/createContext"; |
||||
const ort = require("onnxruntime-web"); |
||||
/* @ts-ignore */ |
||||
import npyjs from "npyjs"; |
||||
|
||||
// Define image, embedding and model paths
|
||||
const IMAGE_PATH = "/assets/data/dogs.jpg"; |
||||
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy"; |
||||
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx"; |
||||
|
||||
const App = () => { |
||||
const { |
||||
clicks: [clicks], |
||||
image: [, setImage], |
||||
maskImg: [, setMaskImg], |
||||
} = useContext(AppContext)!; |
||||
const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
|
||||
const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor
|
||||
|
||||
// The ONNX model expects the input to be rescaled to 1024.
|
||||
// The modelScale state variable keeps track of the scale values.
|
||||
const [modelScale, setModelScale] = useState<modelScaleProps | null>(null); |
||||
|
||||
// Initialize the ONNX model. load the image, and load the SAM
|
||||
// pre-computed image embedding
|
||||
useEffect(() => { |
||||
// Initialize the ONNX model
|
||||
const initModel = async () => { |
||||
try { |
||||
if (MODEL_DIR === undefined) return; |
||||
const URL: string = MODEL_DIR; |
||||
const model = await InferenceSession.create(URL); |
||||
setModel(model); |
||||
} catch (e) { |
||||
console.log(e); |
||||
} |
||||
}; |
||||
initModel(); |
||||
|
||||
// Load the image
|
||||
const url = new URL(IMAGE_PATH, location.origin); |
||||
loadImage(url); |
||||
|
||||
// Load the Segment Anything pre-computed embedding
|
||||
Promise.resolve(loadNpyTensor(IMAGE_EMBEDDING, "float32")).then( |
||||
(embedding) => setTensor(embedding) |
||||
); |
||||
}, []); |
||||
|
||||
const loadImage = async (url: URL) => { |
||||
try { |
||||
const img = new Image(); |
||||
img.src = url.href; |
||||
img.onload = () => { |
||||
const { height, width, samScale } = handleImageScale(img); |
||||
setModelScale({ |
||||
height: height, // original image height
|
||||
width: width, // original image width
|
||||
samScale: samScale, // scaling factor for image which has been resized to longest side 1024
|
||||
}); |
||||
img.width = width; |
||||
img.height = height; |
||||
setImage(img); |
||||
}; |
||||
} catch (error) { |
||||
console.log(error); |
||||
} |
||||
}; |
||||
|
||||
// Decode a Numpy file into a tensor.
|
||||
const loadNpyTensor = async (tensorFile: string, dType: string) => { |
||||
let npLoader = new npyjs(); |
||||
const npArray = await npLoader.load(tensorFile); |
||||
const tensor = new ort.Tensor(dType, npArray.data, npArray.shape); |
||||
return tensor; |
||||
}; |
||||
|
||||
// Run the ONNX model every time clicks has changed
|
||||
useEffect(() => { |
||||
runONNX(); |
||||
}, [clicks]); |
||||
|
||||
const runONNX = async () => { |
||||
try { |
||||
if ( |
||||
model === null || |
||||
clicks === null || |
||||
tensor === null || |
||||
modelScale === null |
||||
) |
||||
return; |
||||
else { |
||||
// Preapre the model input in the correct format for SAM.
|
||||
// The modelData function is from onnxModelAPI.tsx.
|
||||
const feeds = modelData({ |
||||
clicks, |
||||
tensor, |
||||
modelScale, |
||||
}); |
||||
if (feeds === undefined) return; |
||||
// Run the SAM ONNX model with the feeds returned from modelData()
|
||||
const results = await model.run(feeds); |
||||
const output = results[model.outputNames[0]]; |
||||
// The predicted mask returned from the ONNX model is an array which is
|
||||
// rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx.
|
||||
setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3])); |
||||
} |
||||
} catch (e) { |
||||
console.log(e); |
||||
} |
||||
}; |
||||
|
||||
return <Stage />; |
||||
}; |
||||
|
||||
export default App; |
After Width: | Height: | Size: 438 KiB |
@ -0,0 +1,18 @@
@@ -0,0 +1,18 @@
|
||||
<!DOCTYPE html> |
||||
<html lang="en" dir="ltr" prefix="og: https://ogp.me/ns#" class="w-full h-full"> |
||||
<head> |
||||
<meta charset="utf-8" /> |
||||
<meta |
||||
name="viewport" |
||||
content="width=device-width, initial-scale=1, shrink-to-fit=no" |
||||
/> |
||||
<title>Segment Anything Demo</title> |
||||
|
||||
<!-- Meta Tags --> |
||||
<meta property="og:type" content="website" /> |
||||
<meta property="og:title" content="Segment Anything Demo" /> |
||||
</head> |
||||
<body class="w-full h-full"> |
||||
<div id="root" class="w-full h-full"></div> |
||||
</body> |
||||
</html> |
@ -0,0 +1,3 @@
@@ -0,0 +1,3 @@
|
||||
@tailwind base; |
||||
@tailwind components; |
||||
@tailwind utilities; |
@ -0,0 +1,43 @@
@@ -0,0 +1,43 @@
|
||||
import React, { useContext } from "react";
import * as _ from "underscore";
import Tool from "./Tool";
import { modelInputProps } from "./helpers/Interfaces";
import AppContext from "./hooks/createContext";

// Stage: wraps Tool and translates mouse movement into SAM click
// prompts, publishing them through the shared app context.
const Stage = () => {
  const {
    clicks: [, setClicks],
    image: [image],
  } = useContext(AppContext)!;

  // Build a click prompt; clickType 1 marks a foreground (positive) point.
  const getClick = (x: number, y: number): modelInputProps => {
    const clickType = 1;
    return { x, y, clickType };
  };

  // Get mouse position and scale the (x, y) coordinates back to the natural
  // scale of the image. Update the state of clicks with setClicks to trigger
  // the ONNX model to run and generate a new mask via a useEffect in App.tsx
  // Throttled to at most once per 15 ms to bound model invocations.
  const handleMouseMove = _.throttle((e: any) => {
    let el = e.nativeEvent.target;
    const rect = el.getBoundingClientRect();
    let x = e.clientX - rect.left;
    let y = e.clientY - rect.top;
    // Ratio of the image's natural width to its rendered width; maps
    // viewport coordinates back onto image pixels.
    const imageScale = image ? image.width / el.offsetWidth : 1;
    x *= imageScale;
    y *= imageScale;
    const click = getClick(x, y);
    if (click) setClicks([click]);
  }, 15);

  const flexCenterClasses = "flex items-center justify-center";
  return (
    <div className={`${flexCenterClasses} w-full h-full`}>
      <div className={`${flexCenterClasses} relative w-[90%] h-[90%]`}>
        <Tool handleMouseMove={handleMouseMove} />
      </div>
    </div>
  );
};

export default Stage;
@ -0,0 +1,67 @@
@@ -0,0 +1,67 @@
|
||||
import React, { useContext, useEffect, useState } from "react"; |
||||
import AppContext from "./hooks/createContext"; |
||||
import { ToolProps } from "./helpers/Interfaces"; |
||||
import * as _ from "underscore"; |
||||
|
||||
const Tool = ({ handleMouseMove }: ToolProps) => { |
||||
const { |
||||
image: [image], |
||||
maskImg: [maskImg, setMaskImg], |
||||
} = useContext(AppContext)!; |
||||
|
||||
// Determine if we should shrink or grow the images to match the
|
||||
// width or the height of the page and setup a ResizeObserver to
|
||||
// monitor changes in the size of the page
|
||||
const [shouldFitToWidth, setShouldFitToWidth] = useState(true); |
||||
const bodyEl = document.body; |
||||
const fitToPage = () => { |
||||
if (!image) return; |
||||
const imageAspectRatio = image.width / image.height; |
||||
const screenAspectRatio = window.innerWidth / window.innerHeight; |
||||
setShouldFitToWidth(imageAspectRatio > screenAspectRatio); |
||||
}; |
||||
const resizeObserver = new ResizeObserver((entries) => { |
||||
for (const entry of entries) { |
||||
if (entry.target === bodyEl) { |
||||
fitToPage(); |
||||
} |
||||
} |
||||
}); |
||||
useEffect(() => { |
||||
fitToPage(); |
||||
resizeObserver.observe(bodyEl); |
||||
return () => { |
||||
resizeObserver.unobserve(bodyEl); |
||||
}; |
||||
}, [image]); |
||||
|
||||
const imageClasses = ""; |
||||
const maskImageClasses = `absolute opacity-40 pointer-events-none`; |
||||
|
||||
// Render the image and the predicted mask image on top
|
||||
return ( |
||||
<> |
||||
{image && ( |
||||
<img |
||||
onMouseMove={handleMouseMove} |
||||
onMouseOut={() => _.defer(() => setMaskImg(null))} |
||||
onTouchStart={handleMouseMove} |
||||
src={image.src} |
||||
className={`${ |
||||
shouldFitToWidth ? "w-full" : "h-full" |
||||
} ${imageClasses}`}
|
||||
></img> |
||||
)} |
||||
{maskImg && ( |
||||
<img |
||||
src={maskImg.src} |
||||
className={`${ |
||||
shouldFitToWidth ? "w-full" : "h-full" |
||||
} ${maskImageClasses}`}
|
||||
></img> |
||||
)} |
||||
</> |
||||
); |
||||
}; |
||||
|
||||
export default Tool; |
@ -0,0 +1,23 @@
@@ -0,0 +1,23 @@
|
||||
import { Tensor } from "onnxruntime-web";

// Original image size plus the scale factor SAM expects (longest side
// resized to 1024; see helpers/scaleHelper.tsx).
export interface modelScaleProps {
  samScale: number;
  height: number;
  width: number;
}

// A single click prompt in natural-image coordinates. clickType is fed
// to the model as the point label (callers use 1 for foreground).
export interface modelInputProps {
  x: number;
  y: number;
  clickType: number;
}

// Input bundle for modelData() in helpers/onnxModelAPI.tsx.
export interface modeDataProps {
  clicks?: Array<modelInputProps>;
  tensor: Tensor;
  modelScale: modelScaleProps;
}

// Props for the Tool component: the (throttled) mouse-move handler.
export interface ToolProps {
  handleMouseMove: (e: any) => void;
}
@ -0,0 +1,43 @@
@@ -0,0 +1,43 @@
|
||||
// Functions for handling mask output from the ONNX model

// Convert the onnx model mask prediction to ImageData
// Pixels where the mask logit is positive are painted a translucent
// blue; everything else stays transparent (alpha 0).
function arrayToImageData(input: any, width: number, height: number) {
  const [r, g, b, a] = [0, 114, 189, 255]; // the mask's blue color
  const arr = new Uint8ClampedArray(4 * width * height).fill(0);
  for (let i = 0; i < input.length; i++) {

    // Threshold the onnx model mask prediction at 0.0
    // This is equivalent to thresholding the mask using predictor.model.mask_threshold
    // in python
    if (input[i] > 0.0) {
      arr[4 * i + 0] = r;
      arr[4 * i + 1] = g;
      arr[4 * i + 2] = b;
      arr[4 * i + 3] = a;
    }
  }
  // NOTE(review): the ImageData arguments are (data, width, height) per
  // the DOM spec, but here `height` is passed first. The buffer length
  // 4*width*height is symmetric so this only works because callers pass
  // the dimensions pre-swapped (App.tsx passes dims[2], dims[3]) —
  // verify before changing either side.
  return new ImageData(arr, height, width);
}
||||
|
||||
// Use a Canvas element to produce an image from ImageData
|
||||
function imageDataToImage(imageData: ImageData) { |
||||
const canvas = imageDataToCanvas(imageData); |
||||
const image = new Image(); |
||||
image.src = canvas.toDataURL(); |
||||
return image; |
||||
} |
||||
|
||||
// Canvas elements can be created from ImageData
|
||||
function imageDataToCanvas(imageData: ImageData) { |
||||
const canvas = document.createElement("canvas"); |
||||
const ctx = canvas.getContext("2d"); |
||||
canvas.width = imageData.width; |
||||
canvas.height = imageData.height; |
||||
ctx?.putImageData(imageData, 0, 0); |
||||
return canvas; |
||||
} |
||||
|
||||
// Convert the onnx model mask output to an HTMLImageElement
|
||||
export function onnxMaskToImage(input: any, width: number, height: number) { |
||||
return imageDataToImage(arrayToImageData(input, width, height)); |
||||
} |
@ -0,0 +1,65 @@
@@ -0,0 +1,65 @@
|
||||
import { Tensor } from "onnxruntime-web";
import { modeDataProps } from "./Interfaces";

// Build the feed dictionary for the SAM ONNX decoder from the image
// embedding, the click prompts and the model scale. Returns undefined
// when there are no clicks (the point tensors cannot be built).
const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
  const imageEmbedding = tensor;
  let pointCoords;
  let pointLabels;
  let pointCoordsTensor;
  let pointLabelsTensor;

  // Check there are input click prompts
  if (clicks) {
    let n = clicks.length;

    // If there is no box input, a single padding point with
    // label -1 and coordinates (0.0, 0.0) should be concatenated
    // so initialize the array to support (n + 1) points.
    pointCoords = new Float32Array(2 * (n + 1));
    pointLabels = new Float32Array(n + 1);

    // Add clicks and scale to what SAM expects
    // (coordinates in the 1024-longest-side resized frame).
    for (let i = 0; i < n; i++) {
      pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
      pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
      pointLabels[i] = clicks[i].clickType;
    }

    // Add in the extra point/label when only clicks and no box
    // The extra point is at (0, 0) with label -1
    pointCoords[2 * n] = 0.0;
    pointCoords[2 * n + 1] = 0.0;
    pointLabels[n] = -1.0;

    // Create the tensor
    // Shapes: coords [1, n+1, 2], labels [1, n+1].
    pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]);
    pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]);
  }
  // Original image size, used by the decoder to scale the mask back up.
  const imageSizeTensor = new Tensor("float32", [
    modelScale.height,
    modelScale.width,
  ]);

  // No click prompts: nothing to feed the model with.
  if (pointCoordsTensor === undefined || pointLabelsTensor === undefined)
    return;

  // There is no previous mask, so default to an empty tensor
  const maskInput = new Tensor(
    "float32",
    new Float32Array(256 * 256),
    [1, 1, 256, 256]
  );
  // There is no previous mask, so default to 0
  const hasMaskInput = new Tensor("float32", [0]);

  // Keys match the input names of the exported SAM ONNX decoder.
  return {
    image_embeddings: imageEmbedding,
    point_coords: pointCoordsTensor,
    point_labels: pointLabelsTensor,
    orig_im_size: imageSizeTensor,
    mask_input: maskInput,
    has_mask_input: hasMaskInput,
  };
};

export { modelData };
@ -0,0 +1,12 @@
@@ -0,0 +1,12 @@
|
||||
|
||||
// Helper function for handling image scaling needed for SAM
|
||||
const handleImageScale = (image: HTMLImageElement) => { |
||||
// Input images to SAM must be resized so the longest side is 1024
|
||||
const LONG_SIDE_LENGTH = 1024; |
||||
let w = image.naturalWidth; |
||||
let h = image.naturalHeight; |
||||
const samScale = LONG_SIDE_LENGTH / Math.max(h, w); |
||||
return { height: h, width: w, samScale }; |
||||
}; |
||||
|
||||
export { handleImageScale }; |
@ -0,0 +1,25 @@
@@ -0,0 +1,25 @@
|
||||
import React, { useState } from "react";
import { modelInputProps } from "../helpers/Interfaces";
import AppContext from "./createContext";

// Provider that owns the app's shared state — click prompts, the
// source image, and the predicted mask image — and exposes each as a
// [value, setter] pair through AppContext.
const AppContextProvider = (props: {
  children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
}) => {
  const [clicks, setClicks] = useState<Array<modelInputProps> | null>(null);
  const [image, setImage] = useState<HTMLImageElement | null>(null);
  const [maskImg, setMaskImg] = useState<HTMLImageElement | null>(null);

  return (
    <AppContext.Provider
      value={{
        clicks: [clicks, setClicks],
        image: [image, setImage],
        maskImg: [maskImg, setMaskImg],
      }}
    >
      {props.children}
    </AppContext.Provider>
  );
};

export default AppContextProvider;
@ -0,0 +1,21 @@
@@ -0,0 +1,21 @@
|
||||
import { createContext } from "react";
import { modelInputProps } from "../helpers/Interfaces";

// Shape of the shared app state: each entry is a [value, setter] tuple
// mirroring React's useState return, populated by AppContextProvider.
interface contextProps {
  clicks: [
    clicks: modelInputProps[] | null,
    setClicks: (e: modelInputProps[] | null) => void
  ];
  image: [
    image: HTMLImageElement | null,
    setImage: (e: HTMLImageElement | null) => void
  ];
  maskImg: [
    maskImg: HTMLImageElement | null,
    setMaskImg: (e: HTMLImageElement | null) => void
  ];
}

// Null until a Provider is mounted; consumers use a non-null assertion.
const AppContext = createContext<contextProps | null>(null);

export default AppContext;
@ -0,0 +1,11 @@
@@ -0,0 +1,11 @@
|
||||
import * as React from "react";
import { createRoot } from "react-dom/client";
import AppContextProvider from "./components/hooks/context";
import App from "./App";

// Entry point: mount the app (wrapped in its context provider) into
// the #root element declared in src/assets/index.html.
const container = document.getElementById("root");
const root = createRoot(container!);
root.render(
  <AppContextProvider>
    <App/>
  </AppContextProvider>
);
@ -0,0 +1,6 @@
@@ -0,0 +1,6 @@
|
||||
/** @type {import('tailwindcss').Config} */
// Scan all templates and sources under src/ for Tailwind class usage.
module.exports = {
  content: ["./src/**/*.{html,js,tsx}"],
  theme: {},
  plugins: [],
};
@ -0,0 +1,24 @@
@@ -0,0 +1,24 @@
|
||||
{ |
||||
"compilerOptions": { |
||||
"lib": ["dom", "dom.iterable", "esnext"], |
||||
"allowJs": true, |
||||
"skipLibCheck": true, |
||||
"strict": true, |
||||
"forceConsistentCasingInFileNames": true, |
||||
"noEmit": false, |
||||
"esModuleInterop": true, |
||||
"module": "esnext", |
||||
"moduleResolution": "node", |
||||
"resolveJsonModule": true, |
||||
"isolatedModules": true, |
||||
"jsx": "react", |
||||
"incremental": true, |
||||
"target": "ESNext", |
||||
"useDefineForClassFields": true, |
||||
"allowSyntheticDefaultImports": true, |
||||
"outDir": "./dist/", |
||||
"sourceMap": true |
||||
}, |
||||
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", "src"], |
||||
"exclude": ["node_modules"] |
||||
} |
Loading…
Reference in new issue