Object Detection
In this tutorial, you will implement a function that takes in an image and runs it through an object detection model to find what objects are in the image and where.
Viewing this Demo
In order to view this demo, download the PlayTorch app.
Preview
If you want a sneak peek at what you'll be building, run this Snack by scanning the QR code in the PlayTorch app!
Overview
This tutorial will focus on the machine learning aspects of creating this object detection prototype. In order to focus on machine learning, we will provide the UI code for you.
We'll go through the following steps:
- Create a copy of the starter Expo Snack
- Run the project in the PlayTorch app
- Create an image handler function
- Convert image to tensor ready for inference
- Load model and run inference
- Format model output
- Display results
Copying the Starter Snack
Click the button below to open your own copy of the starter Expo Snack for this tutorial. It will open a new tab to an Expo Snack, the in-browser code editor we will use for this tutorial that allows us to test our code in the PlayTorch app.
Click to Copy Starter Snack
Let's run through a quick overview of everything that you can find in the starter snack.
screens/
├─ CameraScreen.js
├─ LoadingScreen.js
├─ ResultsScreen.js
App.js
CoCoClasses.json
ObjectDetector.js
package.json
screens/
The screens folder contains three files -- one for each screen of the object detection prototype we are building.
- CameraScreen.js - A full screen camera view to capture the image we will run through our object detection model
- LoadingScreen.js - A loading screen that lets the user know the image is being processed
- ResultsScreen.js - A screen showing the image and the bounding boxes of the detected objects
App.js
This file manages which screen to show and helps pass data between them.
CoCoClasses.json
This file contains the labels for the different classes of objects that our model has been trained to detect.
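If you're curious how these labels end up being used, here is a tiny sketch. It only assumes what we will see later in this tutorial (the file is imported as an array of label strings); the exact label values depend on the file itself, so treat the example output as hypothetical.

import COCO_CLASSES from './CoCoClasses.json';

// The file is an array of label strings, so a predicted class index maps
// directly to a human-readable name (something like "person" or "cup").
console.log(COCO_CLASSES.length); // how many labels the model can predict
console.log(COCO_CLASSES[1]); // one of the label strings (exact value depends on the file)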
ObjectDetector.js
This file will contain the code that actually runs our machine learning model.
package.json
This file contains a list of packages that our project depends on so the app knows to download them.
Running the project in the PlayTorch app
Open the PlayTorch app and from the home screen, tap "Scan QR Code".
If you have never done this before, it will ask for camera permissions. Grant the app camera permissions and scan the QR code from the right side of the Snack window.
If you haven't made any changes to the snack, you should see a screen that looks like this (with a different view in your camera of course):
In this starter state, nothing happens when you press the camera's capture button. Let's fix that.
Create an image handler function
In the initial state of the codebase, if we look into our CameraScreen.js file, we will notice:
- It is expecting an onCapture function as a prop (line 6)
- It passes the onCapture prop to the <Camera /> component provided by react-native-pytorch-core (line 12)
import * as React from 'react';
import {useRef} from 'react';
import {StyleSheet} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
export default function CameraScreen({onCapture}) {
const cameraRef = useRef(null);
return (
<Camera
ref={cameraRef}
style={styles.camera}
onCapture={onCapture}
targetResolution={{width: 1080, height: 1920}}
/>
);
}
const styles = StyleSheet.create({
camera: {width: '100%', height: '100%'},
});
However, currently the code in our App.js file doesn't pass anything to our <CameraScreen /> component. Make the changes in the code below to your App.js file to create a function for handling captured images.
Here is a quick summary of the changes that we are making:
- Create a new async function called handleImage that does the following:
  - Log the image so we can inspect it
  - Release the image so it doesn't leak memory
- Pass the new handleImage function to the onCapture prop
import * as React from 'react';
import {useCallback, useRef, useState} from 'react';
import {
ActivityIndicator,
Button,
SafeAreaView,
StyleSheet,
Text,
TouchableOpacity,
View,
} from 'react-native';
import {Camera, Canvas, Image} from 'react-native-pytorch-core';
import detectObjects from './ObjectDetector';
import CameraScreen from './screens/CameraScreen';
import LoadingScreen from './screens/LoadingScreen';
import ResultsScreen from './screens/ResultsScreen';
const ScreenStates = {
CAMERA: 0,
LOADING: 1,
RESULTS: 2,
};
export default function ObjectDetectionExample() {
const [image, setImage] = useState(null);
const [boundingBoxes, setBoundingBoxes] = useState(null);
const [screenState, setScreenState] = useState(ScreenStates.CAMERA);
// Handle the reset button and return to the camera capturing mode
const handleReset = useCallback(async () => {
setScreenState(ScreenStates.CAMERA);
if (image != null) {
await image.release();
}
setImage(null);
setBoundingBoxes(null);
}, [image, setScreenState]);
// This handler function handles the camera's capture event
async function handleImage(capturedImage) {
console.log('Captured image', capturedImage);
capturedImage.release();
}
return (
<SafeAreaView style={styles.container}>
{screenState === ScreenStates.CAMERA && (
<CameraScreen onCapture={handleImage} />
)}
{screenState === ScreenStates.LOADING && <LoadingScreen />}
{screenState === ScreenStates.RESULTS && (
<ResultsScreen
image={image}
boundingBoxes={boundingBoxes}
onReset={handleReset}
/>
)}
</SafeAreaView>
);
}
const styles = StyleSheet.create({
container: {
flex: 1,
},
});
Run the updated code on the PlayTorch app, press the camera capture button and then check the logs to see what it output.
Open the logs in the Snack window by clicking the settings gear icon at the bottom of the window, enabling the Panel, and clicking the logs tab of the newly opened panel.
You should see something like the following, but with a different ID:
Captured image {"ID":"ABEF4518-A791-45C3-BA5C-C424F96C58EC"}
Now that we have access to the images that we are capturing, let's write some code to convert them into a format that our ML model can understand.
Convert image to tensor ready for inference
In order to keep our code clean and simple to understand, we will separate all of the code that interacts with the ML model into its own file. That file already exists and is called ObjectDetector.js.
Open that file and notice it is a rather bare file. All that it does at the moment is export an async function called detectObjects that expects an image as a parameter.
export default async function detectObjects(image) {}
The model we will be using for object detection is called DETR, which stands for Detection Transformer.
The DETR model only knows how to process data in a very specific format. We will be converting our image into a tensor (a multi-dimensional array) and then rearranging the data in that tensor to be formatted just how DETR expects it.
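Before looking at the full conversion code, here is a tiny sketch of the key idea, assuming the same PlayTorch torch API we use below: the camera gives us image data in height x width x channels (HWC) order, and permute moves the channel dimension to the front (CHW), which is the layout DETR expects.

import {torch} from 'react-native-pytorch-core';

// A stand-in for a 2 x 2 pixel RGB image in HWC layout (values are random)
const hwc = torch.rand([2, 2, 3]);
console.log(hwc.shape); // [2, 2, 3]

// Move the channel dimension to the front: HWC -> CHW
const chw = hwc.permute([2, 0, 1]);
console.log(chw.shape); // [3, 2, 2]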
The code below does all the transformations needed to the image. Here's a summary of what is going on:
- Import media, torch, and torchvision from the PlayTorch SDK (react-native-pytorch-core)
- Create a variable T as a shortcut for accessing torchvision.transforms, since we will be using the transforms a lot
- Update the empty detectObjects function to do the following:
  - Store the dimensions of the image in imageWidth and imageHeight variables
  - Create a blob (a raw memory store of file contents, which in this case contains the RGB byte values) from the image parameter using the media.toBlob function
  - Create a tensor from the raw data in the newly created blob, with the shape defined by the array [imageHeight, imageWidth, 3] (the size of the image and 3 for RGB)
  - Rearrange the order of the tensor so the channels (RGB) come first, then the height and width
  - Transform all the values in the tensor to be floating point numbers between 0 and 1 instead of byte values between 0 and 255
  - Create a resize transformation with size 800 called resize
  - Apply the resize transformation to the tensor
  - Create a normalization transformation called normalize with the mean and standard deviation values that the model was trained on
  - Apply the normalize transformation to the tensor
  - Add an extra leading dimension to the tensor with tensor.unsqueeze
  - Log the final shape to make sure it matches the format the model needs, which is [1, 3, height, width]
import {media, torch, torchvision} from 'react-native-pytorch-core';
const T = torchvision.transforms;
export default async function detectObjects(image) {
// Get image width and height
const imageWidth = image.getWidth();
const imageHeight = image.getHeight();
// Convert image to blob, which is a byte representation of the image
// in the format height (H), width (W), and channels (C), or HWC for short
const blob = media.toBlob(image);
// Get a tensor from the image blob and also define what format
// the image blob is in.
let tensor = torch.fromBlob(blob, [imageHeight, imageWidth, 3]);
// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);
// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);
// Resize the image tensor so that its smaller edge is 800 pixels
const resize = T.resize(800);
tensor = resize(tensor);
// Normalize the tensor image with mean and standard deviation
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);
// Unsqueeze adds 1 leading dimension to the tensor
const formattedInputTensor = tensor.unsqueeze(0);
console.log(formattedInputTensor);
}
In order to run and test this, let's make a quick update to the App.js file so that the handleImage function calls our updated detectObjects function.
import * as React from 'react';
import {useCallback, useRef, useState} from 'react';
import {
ActivityIndicator,
Button,
SafeAreaView,
StyleSheet,
Text,
TouchableOpacity,
View,
} from 'react-native';
import {Camera, Canvas, Image} from 'react-native-pytorch-core';
import detectObjects from './ObjectDetector';
import CameraScreen from './screens/CameraScreen';
import LoadingScreen from './screens/LoadingScreen';
import ResultsScreen from './screens/ResultsScreen';
const ScreenStates = {
CAMERA: 0,
LOADING: 1,
RESULTS: 2,
};
export default function ObjectDetectionExample() {
const [image, setImage] = useState(null);
const [boundingBoxes, setBoundingBoxes] = useState(null);
const [screenState, setScreenState] = useState(ScreenStates.CAMERA);
// Handle the reset button and return to the camera capturing mode
const handleReset = useCallback(async () => {
setScreenState(ScreenStates.CAMERA);
if (image != null) {
await image.release();
}
setImage(null);
setBoundingBoxes(null);
}, [image, setScreenState]);
// This handler function handles the camera's capture event
async function handleImage(capturedImage) {
await detectObjects(capturedImage);
capturedImage.release();
}
return (
<SafeAreaView style={styles.container}>
{screenState === ScreenStates.CAMERA && (
<CameraScreen onCapture={handleImage} />
)}
{screenState === ScreenStates.LOADING && <LoadingScreen />}
{screenState === ScreenStates.RESULTS && (
<ResultsScreen
image={image}
boundingBoxes={boundingBoxes}
onReset={handleReset}
/>
)}
</SafeAreaView>
);
}
const styles = StyleSheet.create({
container: {
flex: 1,
},
});
Now that we are calling the updated detectObjects function, let's test it by running the snack in the PlayTorch app.
When you press the capture button now, you should see something similar to the following in the logs:
▼{dtype:"float32",shape:[…]}
dtype:"float32"
►shape:[1,3,1422,800]
The exact height and width will vary with your camera's resolution, since the resize transform scales the smaller edge of the image to 800 pixels. With our tensor in the right shape and ready for inference, let's load a model and run it!
Load model and run inference
Now we get to actually add some ML to our app! Let's walk through the updates we are making to the ObjectDetector.js file in the block below.
- Add MobileModel to our imports from the PlayTorch SDK (react-native-pytorch-core)
- Create a variable called MODEL_URL and set it equal to the download link for the DETR model (https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/detr_resnet50.ptl)
- Create a variable called model and initially set it to null. We put this variable outside of our detectObjects function so we only have to load it once.
- Check if the model is still null, and if it is, download the model file from the MODEL_URL and then load the model into memory from the newly downloaded file
- Run the DETR model on our formattedInputTensor by calling the model.forward function and store the result in a variable called output
- Log the output to see what our model has generated
import {
media,
MobileModel,
torch,
torchvision,
} from 'react-native-pytorch-core';
const T = torchvision.transforms;
const MODEL_URL =
'https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/detr_resnet50.ptl';
let model = null;
export default async function detectObjects(image) {
// Get image width and height
const imageWidth = image.getWidth();
const imageHeight = image.getHeight();
// Convert image to blob, which is a byte representation of the image
// in the format height (H), width (W), and channels (C), or HWC for short
const blob = media.toBlob(image);
// Get a tensor from the image blob and also define what format
// the image blob is in.
let tensor = torch.fromBlob(blob, [imageHeight, imageWidth, 3]);
// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);
// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);
// Resize the image tensor so that its smaller edge is 800 pixels
const resize = T.resize(800);
tensor = resize(tensor);
// Normalize the tensor image with mean and standard deviation
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);
// Unsqueeze adds 1 leading dimension to the tensor
const formattedInputTensor = tensor.unsqueeze(0);
// Load model if not loaded
if (model == null) {
console.log('Loading model...');
const filePath = await MobileModel.download(MODEL_URL);
model = await torch.jit._loadForMobile(filePath);
console.log('Model successfully loaded');
}
// Run inference
const output = await model.forward(formattedInputTensor);
console.log(output);
}
Now that we have code to download and run the model, let's see what it produces!
Remember that it will be slower the first time you run it because it will need to download the model. After that, it should run much faster. It will also need to redownload the model each time you update the code and it reloads on your device. There are plenty of ways to get around this, but we will not cover them in this tutorial.
When you press the capture button now, after the model is done downloading, you should see something similar to the following in the logs:
Loading model...
Model successfully loaded
▼{pred_logits:{…},pred_boxes:{…}}
▼pred_logits:{dtype:"float32",shape:[…]}
dtype:"float32"
►shape:[1,100,92]
▼pred_boxes:{dtype:"float32",shape:[…]}
dtype:"float32"
►shape:[1,100,4]
Hurray! Our model is running! The output seems a bit confusing at first glance though 🤔
Let's unpack this output and format it for the rest of our app to use.
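A bit of context first (this describes how DETR is defined, not anything PlayTorch-specific): the model always makes 100 predictions per image. pred_logits holds 92 class scores for each of those predictions (including a special "no object" class), and pred_boxes holds 4 numbers per prediction describing the box center, width, and height, normalized to values between 0 and 1. Here is a minimal sketch of how we will index into this output in the next section:

// Drop the leading batch dimension of 1
const predLogits = output.pred_logits.squeeze(0); // shape [100, 92]
const predBoxes = output.pred_boxes.squeeze(0); // shape [100, 4]

// Each prediction has one box: [centerX, centerY, width, height], all in [0, 1]
const firstBox = predBoxes[0].data();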
Format model output
In order to make the output easy to use in the rest of our app, we will need to transform it. Here's a summary of the changes we're making to the ObjectDetector.js file:
- Import the class labels from the CoCoClasses.json file as COCO_CLASSES
- Create a variable called probabilityThreshold set to 0.7, meaning we will only consider an object detected if the model has a confidence score of over 0.7 for that detection
- Extract the predLogits and predBoxes from the output. Note that "pred" is short for predicted, and we use the .squeeze(0) method to get rid of the leading 1 from the shape we saw in the logged output so we are left with just the prediction data.
- Grab the number of predictions from the predLogits tensor and store it in a variable called numPredictions
- Create an empty array called resultBoxes
- Loop over each of the model predictions and do the following:
  - Grab the confidence scores of the current prediction for each of the different labels and store them in a variable called confidencesTensor
  - Convert the confidencesTensor into a tensor of probabilities between 0 and 1 by using the softmax function and store that in a variable called scores (see the short sketch after this list)
  - Find the index of the highest scoring class with the argmax function and save it to a variable called maxIndex
  - Use maxIndex to grab the highest probability out of the scores tensor and store it in a variable called maxProb
  - If maxProb is below our probabilityThreshold or maxIndex is beyond the number of classes we have, skip the rest of the loop that adds the prediction to the results by using the continue statement
  - Grab the current box from the predBoxes tensor and store it in the boxTensor variable
  - Extract the coordinates of the center point of the box (centerX, centerY) as well as the dimensions of the box (boxWidth, boxHeight) from the boxTensor
  - Calculate the coordinates (x, y) of the top-left corner of the box
  - Create an array that contains the coordinates of the top-left corner of the box as well as its dimensions, scaled to the image size, and store it in a variable called bounds
  - Create a match object that contains the bounds array we just created as well as the class label, which we can grab out of the COCO_CLASSES array by indexing into it with maxIndex
  - Add the new match object to our resultBoxes array with the .push method
- Log the result boxes so we can inspect them
- Return the result boxes so whoever calls this function can use them
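Before looking at the full code, here is a short sketch of the two pieces that tend to be the least familiar. First, softmax and argmax, using a made-up three-class logits tensor and the same PlayTorch tensor API as the rest of this tutorial:

import {torch} from 'react-native-pytorch-core';

// Raw logits for a hypothetical 3-class prediction
const logits = torch.tensor([1.0, 3.0, 0.5]);

// softmax turns logits into probabilities that sum to 1
const scores = logits.softmax(0); // roughly [0.11, 0.82, 0.07]

// argmax picks the index of the largest value
const maxIndex = logits.argmax().item(); // 1
const maxProb = scores[maxIndex].item(); // roughly 0.82

Second, the box conversion, worked through with hypothetical numbers. DETR boxes are [centerX, centerY, width, height], normalized to values between 0 and 1, so we shift the center to a top-left corner and then scale by the image size:

// Hypothetical 1080 x 1920 image and a normalized box from the model
const imageWidth = 1080;
const imageHeight = 1920;
const [centerX, centerY, boxWidth, boxHeight] = [0.5, 0.5, 0.2, 0.4];

// Shift from the center to the top-left corner
const x = centerX - boxWidth / 2; // 0.4
const y = centerY - boxHeight / 2; // 0.3

// Scale everything to pixel coordinates
const bounds = [
  x * imageWidth, // 432
  y * imageHeight, // 576
  boxWidth * imageWidth, // 216
  boxHeight * imageHeight, // 768
];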
import {
media,
MobileModel,
torch,
torchvision,
} from 'react-native-pytorch-core';
import COCO_CLASSES from './CoCoClasses.json';
const T = torchvision.transforms;
const MODEL_URL =
'https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/detr_resnet50.ptl';
let model = null;
const probabilityThreshold = 0.7;
export default async function detectObjects(image) {
// Get image width and height
const imageWidth = image.getWidth();
const imageHeight = image.getHeight();
// Convert image to blob, which is a byte representation of the image
// in the format height (H), width (W), and channels (C), or HWC for short
const blob = media.toBlob(image);
// Get a tensor from the image blob and also define what format
// the image blob is in.
let tensor = torch.fromBlob(blob, [imageHeight, imageWidth, 3]);
// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);
// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);
// Resize the image tensor so that its smaller edge is 800 pixels
const resize = T.resize(800);
tensor = resize(tensor);
// Normalize the tensor image with mean and standard deviation
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);
// Unsqueeze adds 1 leading dimension to the tensor
const formattedInputTensor = tensor.unsqueeze(0);
// Load model if not loaded
if (model == null) {
console.log('Loading model...');
const filePath = await MobileModel.download(MODEL_URL);
model = await torch.jit._loadForMobile(filePath);
console.log('Model successfully loaded');
}
// Run inference
const output = await model.forward(formattedInputTensor);
const predLogits = output.pred_logits.squeeze(0);
const predBoxes = output.pred_boxes.squeeze(0);
const numPredictions = predLogits.shape[0];
const resultBoxes = [];
for (let i = 0; i < numPredictions; i++) {
const confidencesTensor = predLogits[i];
const scores = confidencesTensor.softmax(0);
const maxIndex = confidencesTensor.argmax().item();
const maxProb = scores[maxIndex].item();
if (maxProb <= probabilityThreshold || maxIndex >= COCO_CLASSES.length) {
continue;
}
const boxTensor = predBoxes[i];
const [centerX, centerY, boxWidth, boxHeight] = boxTensor.data();
const x = centerX - boxWidth / 2;
const y = centerY - boxHeight / 2;
// Adjust bounds to image size
const bounds = [
x * imageWidth,
y * imageHeight,
boxWidth * imageWidth,
boxHeight * imageHeight,
];
const match = {
objectClass: COCO_CLASSES[maxIndex],
bounds,
};
resultBoxes.push(match);
}
console.log(resultBoxes);
return resultBoxes;
}
That was a lot of changes! Let's test them out to see what our results look like now.
When you press the capture button now, you should see something like this in the logs:
▼[{…},{…},{…}]
▼0:{objectClass:"bowl",bounds:[…]}
objectClass:"bowl"
►bounds:[58.80519703030586,1618.3972656726837,123.18191081285477,98.28389883041382]
►1:{objectClass:"person",bounds:[…]}
►2:{objectClass:"cup",bounds:[…]}
length:3
Now that we have the results in a much more usable format with class labels and a bounding box, let's display our results in the app!
Display results
In order to display the results in our app, all we need to do is make some simple changes in our App.js file.
Before we make any changes, notice that we have a screenState state variable that chooses which screen we show. Currently we only ever show the camera screen, so we will need to update that state variable. You can also find a ScreenStates object that defines the 3 different screen states our app can have.
We should also notice that we have a state variable called boundingBoxes that is currently never given a value. That boundingBoxes state variable is passed to our ResultsScreen, so we will need to update it for our results to be displayed.
With that understanding of how our app works, let's make some updates to our App.js file. All of our changes will be inside our handleImage function. Here's a summary:
- Update the image state variable to be the newly captured image (capturedImage)
- Update the screenState to ScreenStates.LOADING in order to display the LoadingScreen while our model downloads and runs
- In a try block:
  - Await the result of the detectObjects function and store it in a variable called newBoxes
  - Update the boundingBoxes state variable to be the newBoxes
  - Update the screenState to ScreenStates.RESULTS in order to display the ResultsScreen so we can visually see what objects DETR detected in our image
- Add a catch block in case there is an error when we run our model:
  - Call the handleReset function to put the app back in the CameraScreen so we can take a new photo
import * as React from 'react';
import {useCallback, useRef, useState} from 'react';
import {
ActivityIndicator,
Button,
SafeAreaView,
StyleSheet,
Text,
TouchableOpacity,
View,
} from 'react-native';
import {Camera, Canvas, Image} from 'react-native-pytorch-core';
import detectObjects from './ObjectDetector';
import CameraScreen from './screens/CameraScreen';
import LoadingScreen from './screens/LoadingScreen';
import ResultsScreen from './screens/ResultsScreen';
const ScreenStates = {
CAMERA: 0,
LOADING: 1,
RESULTS: 2,
};
export default function ObjectDetectionExample() {
const [image, setImage] = useState(null);
const [boundingBoxes, setBoundingBoxes] = useState(null);
const [screenState, setScreenState] = useState(ScreenStates.CAMERA);
// Handle the reset button and return to the camera capturing mode
const handleReset = useCallback(async () => {
setScreenState(ScreenStates.CAMERA);
if (image != null) {
await image.release();
}
setImage(null);
setBoundingBoxes(null);
}, [image, setScreenState]);
// This handler function handles the camera's capture event
async function handleImage(capturedImage) {
setImage(capturedImage);
// Wait for image to process through DETR model and draw resulting image
setScreenState(ScreenStates.LOADING);
try {
const newBoxes = await detectObjects(capturedImage);
setBoundingBoxes(newBoxes);
// Switch to the ResultsScreen to display the detected objects
setScreenState(ScreenStates.RESULTS);
} catch (err) {
// In case something goes wrong, go back to the CameraScreen to take a new picture
handleReset();
}
}
return (
<SafeAreaView style={styles.container}>
{screenState === ScreenStates.CAMERA && (
<CameraScreen onCapture={handleImage} />
)}
{screenState === ScreenStates.LOADING && <LoadingScreen />}
{screenState === ScreenStates.RESULTS && (
<ResultsScreen
image={image}
boundingBoxes={boundingBoxes}
onReset={handleReset}
/>
)}
</SafeAreaView>
);
}
const styles = StyleSheet.create({
container: {
flex: 1,
},
});
With the results of the detectObjects function hooked up, we are now ready to see the full app working. Load the latest changes into the PlayTorch app and let's test it!
When you press the capture button now, after the model finishes running, you should see the picture you took but with labeled boxes around the detected objects, just like the screenshot below:
Wrapping up
We hope you enjoyed getting an object detection model up and running with PlayTorch.
Providing feedback
Was something unclear? Did you run into any bugs? Is there another tutorial or model you'd like to see from us?
We would love to hear how we can improve.