Skip to content
Snippets Groups Projects
Commit 67858bd1 authored by michael.divia's avatar michael.divia
Browse files

Predict 4 Rpi

parent 1949d2bd
Branches
No related tags found
No related merge requests found
import cv2
import numpy as np
import json
from hailo_platform.pyhailort import HailoRT
# Load class names
with open("../models/ResNet50/class_names.json", "r") as f:
class_names = json.load(f)
# --- Load HEF and configure device ---
hef_path = "../models/ResNet50/resnet50.hef"
device = HailoRT.Device()
hef = HailoRT.Hef(hef_path)
configured_network_group = device.create_hef_group(hef)
input_vstream_info = configured_network_group.get_input_vstream_infos()[0]
output_vstream_info = configured_network_group.get_output_vstream_infos()[0]
# --- Open webcam ---
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("-- Unable to open webcam")
exit()
print("-- Taking picture...")
ret, frame = cap.read()
cap.release()
if not ret:
print("-- Failed to capture image")
exit()
# --- Preprocess image ---
image = cv2.resize(frame, (224, 224))
image = image.astype(np.float32) / 255.0 # Normalize to [0, 1]
image = np.expand_dims(image, axis=0) # Add batch dimension
image = np.transpose(image, (0, 3, 1, 2)) # NHWC ? NCHW if required (check your model)
# --- Inference ---
with HailoRT.VirtualStreams(input_vstream_info, output_vstream_info, configured_network_group) as (input_vstreams, output_vstreams):
input_vstreams[0].send(image)
output_data = output_vstreams[0].recv()
# --- Postprocess ---
predicted_idx = int(np.argmax(output_data))
predicted_name = class_names[predicted_idx]
print(f"-- Predicted Pokémon: {predicted_name}")
network:
name: custom_resnet50
framework: onnx
path: resnet50_pokemon.onnx
input_shape: [1, 224, 224, 3]
mean_std:
mean: [123.68, 116.779, 103.939]
std: [1.0, 1.0, 1.0]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment