"""Classify one webcam frame with a Pokémon classifier running on a Hailo-8.

Usage::

    python <this script> --model 1   # ResNet50 (224x224 input)
    python <this script> --model 2   # Xception (256x256 input)

Prints the predicted Pokémon name; exits with status 1 if the webcam cannot
be opened or a frame cannot be read.
"""

import argparse
import json
import sys
from pathlib import Path
from typing import NamedTuple, Optional, Tuple, Union

import cv2
import numpy as np
from hailo_platform.pyhailort import HailoRT

# The original paths ("../models/...") were relative to the current working
# directory, so they only worked when launched from the script's own folder.
# Resolve them against this file instead. NOTE(review): assumes models/ sits
# one level above this script, as those paths implied — confirm.
_MODELS_DIR = Path(__file__).resolve().parent.parent / "models"


class ModelConfig(NamedTuple):
    """Files and input geometry for one selectable model."""

    hef_path: Path  # compiled Hailo executable
    json_path: Path  # class-name list / mapping written at training time
    input_size: Tuple[int, int]  # (width, height) — the order cv2.resize expects


# Keys double as the accepted values of --model.
MODEL_CONFIGS = {
    "1": ModelConfig(
        _MODELS_DIR / "ResNet50" / "pokedex_ResNet50.hef",
        _MODELS_DIR / "ResNet50" / "class_names.json",
        (224, 224),
    ),
    "2": ModelConfig(
        _MODELS_DIR / "Xception" / "pokedex_Xception.hef",
        _MODELS_DIR / "Xception" / "class_names.json",
        (256, 256),
    ),
}

# ImageNet per-channel means in RGB order (R, G, B).
IMAGENET_MEAN_RGB = np.array([123.68, 116.779, 103.939], dtype=np.float32)


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line; ``--model`` selects an entry of MODEL_CONFIGS."""
    parser = argparse.ArgumentParser(
        description="Pokémon Classifier Inference with Hailo-8"
    )
    parser.add_argument(
        "--model",
        choices=sorted(MODEL_CONFIGS),
        required=True,
        help="1 = ResNet50, 2 = Xception",
    )
    return parser.parse_args(argv)


def load_class_names(json_path: Path) -> Union[list, dict]:
    """Load the class-name JSON (decoded as UTF-8: names such as "Flabébé")."""
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def lookup_class_name(class_names: Union[list, dict], index: int) -> str:
    """Return the label for class *index*.

    Accepts a JSON list of names, an ``{"<index>": name}`` object, or a
    ``{name: index}`` object (the shape of Keras ``class_indices``).

    Raises:
        IndexError: *index* is outside a list of names.
        KeyError: *index* does not appear in a dict mapping.
    """
    if isinstance(class_names, dict):
        # JSON object keys are always strings, so an int key can never match.
        key = str(index)
        if key in class_names:
            return class_names[key]
        for name, class_index in class_names.items():
            if class_index == index:
                return name
        raise KeyError(f"Class index {index} not found in class-name mapping")
    if not 0 <= index < len(class_names):
        raise IndexError(
            f"Class index {index} out of range for {len(class_names)} class names"
        )
    return class_names[index]


def capture_frame(camera_index: int = 0) -> Optional[np.ndarray]:
    """Grab one BGR frame from the webcam.

    Returns None (after printing the reason) if the camera cannot be opened
    or the read fails. The capture handle is always released.
    """
    cap = cv2.VideoCapture(camera_index)
    try:
        if not cap.isOpened():
            print("-- Unable to open webcam")
            return None
        print("-- Capturing image...")
        ok, frame = cap.read()
    finally:
        cap.release()
    if not ok:
        print("-- Failed to capture image")
        return None
    return frame


def preprocess(frame: np.ndarray, input_size: Tuple[int, int]) -> np.ndarray:
    """Resize a BGR frame and mean-subtract it into a (1, H, W, 3) float32 batch.

    OpenCV delivers BGR, while IMAGENET_MEAN_RGB is in RGB order; converting to
    RGB first makes each mean land on its own channel (previously the R and B
    means were swapped).

    The layout stays NHWC: HailoRT vstream shapes are (height, width, channels),
    so an NCHW buffer of the same size would be silently misinterpreted.

    NOTE(review): the same normalization is applied to both models. Keras'
    Xception is normally trained on inputs scaled to [-1, 1] (and some models
    on [0, 1]); confirm what each HEF expects, or whether normalization is
    already compiled into the HEF.
    """
    image = cv2.resize(frame, input_size)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    image -= IMAGENET_MEAN_RGB
    return np.expand_dims(image, axis=0)  # (H, W, C) -> (1, H, W, C)


def main(argv=None) -> int:
    """Run one capture-and-classify pass; return the process exit status."""
    args = parse_args(argv)
    config = MODEL_CONFIGS[args.model]
    class_names = load_class_names(config.json_path)

    # --- Setup device and network ---
    # NOTE(review): HailoRT.Device / create_hef_group / VirtualStreams differ
    # from the API in Hailo's public Python tutorial (hailo_platform.HEF,
    # VDevice, InferVStreams); confirm against the installed hailo_platform.
    # `device` stays bound for the whole run so the network group it created
    # is not left without its owner.
    device = HailoRT.Device()
    hef = HailoRT.Hef(str(config.hef_path))
    network_group = device.create_hef_group(hef)
    input_info = network_group.get_input_vstream_infos()[0]
    output_info = network_group.get_output_vstream_infos()[0]

    # --- Capture and preprocess ---
    frame = capture_frame()
    if frame is None:
        return 1
    image = preprocess(frame, config.input_size)

    # --- Inference ---
    with HailoRT.VirtualStreams(input_info, output_info, network_group) as (
        input_vstreams,
        output_vstreams,
    ):
        input_vstreams[0].send(image)
        output_data = output_vstreams[0].recv()

    # --- Postprocess ---
    # argmax over the flattened output: a (1, N) score vector yields an index in N.
    predicted_idx = int(np.argmax(output_data))
    predicted_name = lookup_class_name(class_names, predicted_idx)
    print(f"🎯 Predicted Pokémon: {predicted_name}")
    return 0


if __name__ == "__main__":
    sys.exit(main())