import cv2
import numpy as np
import json
import argparse
from hailo_platform.pyhailort import HailoRT

# --- Argparse ---
# A single mandatory switch selects which compiled network to run.
parser = argparse.ArgumentParser(
    description="Pokémon Classifier Inference with Hailo-8",
)
parser.add_argument(
    "--model",
    required=True,
    choices=["1", "2"],
    help="1 = ResNet50, 2 = Xception",
)
args = parser.parse_args()

# --- Paths ---
# One row per --model choice: (compiled HEF, class-name JSON, resize target).
# Paths are relative to the current working directory.
_MODEL_FILES = {
    "1": (
        "../models/ResNet50/pokedex_ResNet50.hef",
        "../models/ResNet50/class_names.json",
        (224, 224),
    ),
    "2": (
        "../models/Xception/pokedex_Xception.hef",
        "../models/Xception/class_names.json",
        (256, 256),
    ),
}

if args.model not in _MODEL_FILES:
    raise ValueError("Invalid model selection")
hef_path, json_path, input_shape = _MODEL_FILES[args.model]

# --- Load class names ---
# Read the label table that maps the network's output index to a name.
# JSON text is UTF-8; without an explicit encoding open() falls back to the
# locale codec, which can fail or garble non-ASCII names on non-UTF-8 systems.
with open(json_path, "r", encoding="utf-8") as f:
    class_names = json.load(f)

# --- Setup device and network ---
# Create the device handle, load the compiled model file chosen above and
# bind the two into a network group used for inference further down.
# NOTE(review): the public HailoRT Python tutorials use HEF, VDevice.configure()
# and InferVStreams; confirm that HailoRT.Device, HailoRT.Hef and
# create_hef_group exist in the installed hailo_platform version.
device = HailoRT.Device()
hef = HailoRT.Hef(hef_path)
network_group = device.create_hef_group(hef)

# Only the first input and the first output stream descriptor are kept;
# presumably the classifier has exactly one of each — TODO confirm.
input_info = network_group.get_input_vstream_infos()[0]
output_info = network_group.get_output_vstream_infos()[0]

# --- Open webcam and capture image ---
# Grab exactly one frame from the default camera (index 0). On failure a
# message is printed and the script terminates with exit status 1.
cap = cv2.VideoCapture(0)
try:
    if not cap.isOpened():
        print("-- Unable to open webcam")
        # `exit` is a helper injected by the `site` module for interactive
        # use and may be absent (python -S, frozen builds); SystemExit is a
        # builtin and is always available.
        raise SystemExit(1)

    print("-- Capturing image...")
    ret, frame = cap.read()
finally:
    # Give the camera back on every path, including a failed open or an
    # exception raised while reading.
    cap.release()

# read() reports failure through the flag; also guard against a missing frame
# so the preprocessing step never receives None.
if not ret or frame is None:
    print("-- Failed to capture image")
    raise SystemExit(1)

# --- Preprocess image ---
# cv2.resize takes its target size as (width, height). Both supported models
# use square inputs, so input_shape can be passed through unchanged.
image = cv2.resize(frame, input_shape)

# OpenCV delivers frames in BGR channel order, while the mean values below are
# listed as (R, G, B). Convert first so each mean is subtracted from its own
# channel instead of red's mean landing on blue and vice versa.
# NOTE(review): if the network was trained on BGR input (e.g. Keras
# "caffe"-mode preprocessing), drop the conversion and reverse the mean order
# instead — confirm against the training pipeline.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

# ImageNet per-channel means in RGB order (zero-centering only, no scaling).
# NOTE(review): the same normalization is applied to both models; verify that
# each one was actually trained this way and that it is not already folded
# into the compiled HEF.
image -= np.array([123.68, 116.779, 103.939], dtype=np.float32)
# image /= 255.0  # only if your model was trained with [0,1] normalization

# NHWC → NCHW
# NOTE(review): confirm the input stream really expects channels-first data.
image = np.transpose(image, (2, 0, 1))   # (H, W, C) → (C, H, W)
image = np.expand_dims(image, axis=0)    # Add batch dimension → (1, C, H, W)

# --- Inference ---
# Open the virtual streams for the network group, push the single preprocessed
# frame into the first input stream and read one result from the first output
# stream. The streams are closed when the with-block exits.
# NOTE(review): presumably send()/recv() block until the device has consumed
# the frame and produced a result; confirm against the HailoRT version in use.
with HailoRT.VirtualStreams(input_info, output_info, network_group) as (input_vstreams, output_vstreams):
    input_vstreams[0].send(image)
    output_data = output_vstreams[0].recv()

# --- Postprocess ---
# argmax without an axis works on the flattened output, so a leading batch
# dimension of size 1 does not shift the winning class index.
predicted_idx = int(np.argmax(output_data))

# json.load() only ever produces string keys for a JSON object, so an
# index-keyed mapping ({"0": "...", ...}) must be looked up by the stringified
# index. A JSON array is still indexed by position.
if isinstance(class_names, dict):
    predicted_name = class_names[str(predicted_idx)]
else:
    predicted_name = class_names[predicted_idx]
print(f"🎯 Predicted Pokémon: {predicted_name}")