# pytorch_to_onnx_int8_squeezenet.py
import sys, os, glob, io, time, urllib.request, numpy as np, onnx, onnxruntime as ort
from PIL import Image
import torch, torchvision
import torchvision.transforms as T
from onnxruntime.quantization import (
CalibrationDataReader,
QuantFormat,
QuantType,
calibrate,
quantize_static,
)
# Run inference on the NPU only when the first CLI argument is exactly
# '--use-npu'. The comparison chain already evaluates to a bool, so no
# `True if ... else False` wrapper is needed.
use_npu = len(sys.argv) >= 2 and sys.argv[1] == '--use-npu'
def download_file_if_not_exists(path, url):
    """Download *url* to *path* unless *path* already exists.

    Args:
        path: Destination file path. May be a bare filename (no directory
            component) or include directories, which are created as needed.
        url: URL to fetch when *path* is missing.

    Returns:
        *path*, unchanged, whether or not a download took place.

    Raises:
        urllib.error.URLError: If the download fails.
        OSError: If the destination cannot be created or written.
    """
    if not os.path.exists(path):
        parent_dir = os.path.dirname(path)
        # dirname is "" for a bare filename, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        print(f"Downloading {path} from {url}...")
        # Download to a temporary sibling and rename on success, so an
        # interrupted transfer never leaves a truncated file at *path* that a
        # later run would mistake for a complete download.
        tmp_path = f"{path}.part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # Only present here if the download or rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    return path
weights = torchvision.models.SqueezeNet1_1_Weights.DEFAULT
IMAGE_PATH = download_file_if_not_exists('images/boa-constrictor.jpg', 'https://cdn.edgeimpulse.com/qc-ai-docs/examples/boa-constrictor.jpg')

# Load PyTorch SqueezeNet1_1 from torchvision.
# `weights` is a keyword-only parameter of the model builder; passing it
# positionally relies on a deprecated compatibility shim.
device = "cpu"
model = torchvision.models.squeezenet1_1(weights=weights)
model.eval().to(device)

# Export to ONNX (fp32)
os.makedirs("models", exist_ok=True)
onnx_fp32 = "models/squeezenet1_1_fp32.onnx"
input_size = 224
dummy = torch.randn(1, 3, input_size, input_size, device=device)
# dynamic_axes is deliberately omitted: keep the model shape static for
# QNNExecutionProvider / NPU execution.
torch.onnx.export(
    model,
    dummy,
    onnx_fp32,
    input_names=["input"],
    output_names=["logits"],
    opset_version=13,
    do_constant_folding=True,
)
onnx.checker.check_model(onnx.load(onnx_fp32))
print(f"Exported FP32 ONNX -> {onnx_fp32}")
# Provide a calibration data reader for static INT8 quantization
class ImageFolderDataReader(CalibrationDataReader):
    """Calibration data reader that yields preprocessed images one at a time.

    Each image is resized to 256, center-cropped to 224, converted to a
    tensor and normalized with the ImageNet mean/std listed below, then
    returned as a batch of one.

    Args:
        image_paths: Iterable of image file paths to calibrate on.
        input_name: Name of the model input the samples are fed under.
            Defaults to "input".
    """

    def __init__(self, image_paths, input_name="input"):
        self.image_paths = image_paths
        self.input_name = input_name
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        # Created lazily on the first get_next() call.
        self._iter = None

    def get_next(self):
        """Return the next `{input_name: array}` feed, or None when exhausted."""
        if self._iter is None:
            self._iter = iter(self.image_paths)
        p = next(self._iter, None)
        if p is None:
            return None
        # Close the image file explicitly instead of relying on PIL / the
        # garbage collector to do it.
        with Image.open(p) as img:
            x = self.transform(img.convert("RGB")).unsqueeze(0).numpy()
        return {self.input_name: x}
# Replace with representative images from your domain
calib = ImageFolderDataReader([IMAGE_PATH])

# Read the name of the first graph input from the exported model.
# NOTE(review): the value is currently informational only; the calibration
# reader above feeds its samples under its own (default) input name.
m = onnx.load(onnx_fp32)
onnx_input_name = m.graph.input[0].name

onnx_int8 = "models/squeezenet1_1_int8.onnx"

# Use QDQ format (widely supported); uint8 activations + int8 weights is a common choice.
# quant_format is passed explicitly so the call does what the comment says.
quantize_static(
    model_input=onnx_fp32,
    model_output=onnx_int8,
    calibration_data_reader=calib,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
)
onnx.checker.check_model(onnx.load(onnx_int8))
print(f"Quantized INT8 ONNX -> {onnx_int8}")

# Use the HTP (NPU) backend of ONNX Runtime's QNNExecutionProvider when
# --use-npu is passed in (otherwise CPU)
providers = []
if use_npu:
    providers.append(("QNNExecutionProvider", {"backend_type": "htp"}))
else:
    providers.append("CPUExecutionProvider")

# Preprocess the test image with the transforms bundled with the pretrained
# weights (resize, center crop, scale to [0, 1], ImageNet mean/std
# normalization). This matches what the calibration reader feeds the model;
# merely scaling to [0, 1] would give the network out-of-distribution input.
with Image.open(IMAGE_PATH) as raw_img:
    img = raw_img.convert("RGB")
preprocess = weights.transforms()
input_data = preprocess(img).unsqueeze(0).numpy()

so = ort.SessionOptions()
sess = ort.InferenceSession(onnx_int8, sess_options=so, providers=providers)
actual_providers = sess.get_providers()
print(f"Using providers: {actual_providers}")  # Show which providers are actually loaded

# Build the feed once so the timed loop measures only sess.run().
inputs = sess.get_inputs()
feed = {inputs[0].name: input_data}

# Warm-up run, excluded from the timing below.
_ = sess.run(None, feed)

# Run 10x so we can calculate avg. runtime per inference
start = time.perf_counter()
for _ in range(10):
    out = sess.run(None, feed)
end = time.perf_counter()
def softmax(x, axis=-1):
    """Return the softmax of *x* along *axis*.

    The per-axis maximum is subtracted before exponentiating so large inputs
    do not overflow; this does not change the result.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    normalizer = np.sum(exps, axis=axis, keepdims=True)
    return exps / normalizer
# Turn the logits of the last timed run into per-class probabilities.
logits = np.squeeze(out[0], axis=0)
scores = softmax(logits)

# Take top 5 (highest score first)
top_k_idx = scores.argsort()[::-1][:5]
categories = weights.meta["categories"]

print("\nTop-5 predictions:")
for class_idx in top_k_idx:
    # Fall back to a generic name if the index has no category entry.
    if class_idx < len(categories):
        label = categories[class_idx]
    else:
        label = f"Class {class_idx}"
    print(f"{label}: score={scores[class_idx]}")
print()

# Average over the 10 timed runs, reported in milliseconds.
total_ms = (end - start) * 1000
print(f'Inference took (on average): {total_ms / 10:.4g}ms. per image')