Skip to content
Snippets Groups Projects
Commit 6cb09929 authored by michael.divia's avatar michael.divia
Browse files

MayBe

parent a4d04f4e
No related branches found
No related tags found
No related merge requests found
import cv2
import numpy as np
import json import json
import argparse import argparse
import os import os
from hailo_platform.pyhailort import HailoRT import numpy as np
from picamera2 import Picamera2
from picamera2.devices.hailo import Hailo
import cv2
# --- Argparse --- # --- Argparse ---
parser = argparse.ArgumentParser(description="Pokémon Classifier Inference with Hailo-8") parser = argparse.ArgumentParser(description="Pokémon Classifier Inference with Hailo-8")
...@@ -14,11 +15,9 @@ args = parser.parse_args() ...@@ -14,11 +15,9 @@ args = parser.parse_args()
if args.model == "1": if args.model == "1":
hef_path = "../models/ResNet50/pokedex_ResNet50.hef" hef_path = "../models/ResNet50/pokedex_ResNet50.hef"
json_path = "../models/ResNet50/class_names.json" json_path = "../models/ResNet50/class_names.json"
input_shape = (224, 224)
elif args.model == "2": elif args.model == "2":
hef_path = "../models/Xception/pokedex_Xception.hef" hef_path = "../models/Xception/pokedex_Xception.hef"
json_path = "../models/Xception/class_names.json" json_path = "../models/Xception/class_names.json"
input_shape = (256, 256)
else: else:
raise ValueError("Invalid model selection") raise ValueError("Invalid model selection")
...@@ -26,57 +25,37 @@ else: ...@@ -26,57 +25,37 @@ else:
with open(json_path, "r") as f: with open(json_path, "r") as f:
class_names = json.load(f) class_names = json.load(f)
# --- Setup device and network --- # --- Run inference ---
device = HailoRT.Device() with Hailo(hef_path) as hailo:
hef = HailoRT.Hef(hef_path) # Get model input shape (e.g., 224x224x3 or 256x256x3)
network_group = device.create_hef_group(hef) model_h, model_w, _ = hailo.get_input_shape()
input_info = network_group.get_input_vstream_infos()[0]
output_info = network_group.get_output_vstream_infos()[0]
# --- Open webcam and capture image --- # Setup and start the camera
cap = cv2.VideoCapture(0) picam2 = Picamera2()
if not cap.isOpened(): main = {'size': (model_w, model_h), 'format': 'RGB888'}
print("-- Unable to open webcam") config = picam2.create_preview_configuration(main)
exit(1) picam2.start(config)
print("-- Capturing image...") print("-- Capturing image...")
ret, frame = cap.read() frame = picam2.capture_array()
cap.release()
if not ret: # Optionally display the captured image
print("-- Failed to capture image")
exit(1)
# --- Try to display the captured image ---
try: try:
cv2.imshow("Captured Image", frame) cv2.imshow("Captured Image", frame)
print("-- Press any key to continue...") print("-- Press any key to continue...")
cv2.waitKey(0) cv2.waitKey(0)
cv2.destroyAllWindows() cv2.destroyAllWindows()
except cv2.error as e: except cv2.error:
print("-- GUI display failed, saving and showing with feh instead...") print("-- GUI display failed, saving and showing with feh instead...")
output_path = "/tmp/captured.png" output_path = "/tmp/captured.png"
cv2.imwrite(output_path, frame) cv2.imwrite(output_path, frame)
os.system(f"feh --fullscreen {output_path}") os.system(f"feh --fullscreen {output_path}")
# --- Preprocess image --- # Run inference
image = cv2.resize(frame, input_shape) print("-- Running inference...")
image = image.astype(np.float32) inference_results = hailo.run(frame)
# Standard Hailo normalization
image -= [123.68, 116.779, 103.939]
# NHWC → NCHW
image = np.transpose(image, (2, 0, 1)) # (H, W, C) → (C, H, W)
image = np.expand_dims(image, axis=0) # Add batch dimension → (1, C, H, W)
# --- Inference ---
with HailoRT.VirtualStreams(input_info, output_info, network_group) as (input_vstreams, output_vstreams):
input_vstreams[0].send(image)
output_data = output_vstreams[0].recv()
# --- Postprocess --- # Postprocess: find predicted class
predicted_idx = int(np.argmax(output_data)) predicted_idx = int(np.argmax(inference_results))
predicted_name = class_names[predicted_idx] predicted_name = class_names[predicted_idx]
print(f"-- Predicted Pokémon: {predicted_name}") print(f"-- Predicted Pokémon: {predicted_name}")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment