AI Edge Torch是Google最近发布的一个开源Python库,旨在简化将PyTorch模型部署到移动设备和IoT设备上的过程。作为Google AI Edge生态系统的一部分,这个库为开发者提供了一种便捷的方式,可以将PyTorch模型转换为TensorFlow Lite (TFLite)格式,并在移动端和IoT设备上实现高性能的本地推理。
AI Edge Torch的主要特性和优势包括:
直接的PyTorch集成 - 提供与PyTorch原生感觉一致的API,使转换过程变得简单直观。
优秀的CPU性能和初步的GPU支持 - 在CPU上提供出色的推理性能,同时也开始支持GPU加速。
广泛的模型验证 - 已在70多个来自torchvision、timm、torchaudio和HuggingFace的模型上进行了验证测试。
高覆盖率 - 支持超过70%的PyTorch core_aten算子。
与现有TFLite运行时兼容 - 无需更改部署代码即可使用。
支持Model Explorer可视化 - 在工作流程的多个阶段提供模型可视化功能。
AI Edge Torch的核心功能是PyTorch转换器,它可以将PyTorch模型转换为TFLite格式。使用这个转换器非常简单,只需几行代码即可完成:
import torch import torchvision import ai_edge_torch # 使用预训练的ResNet18模型 resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1) sample_inputs = (torch.randn(1, 3, 224, 224),) # 将PyTorch模型转换为TFLite格式 edge_model = ai_edge_torch.convert(resnet18.eval(), sample_inputs) edge_model.export("resnet18.tflite")
这个简单的示例展示了如何将预训练的ResNet18模型转换为TFLite格式。开发者可以轻松地将这个过程应用到自己的自定义PyTorch模型上。
除了基本的PyTorch转换器之外,AI Edge Torch还提供了一个Generative API,专门用于处理大型语言模型(LLMs)和基于Transformer的模型。这个API支持模型创作和量化,以实现更好的设备端性能。
Generative API的主要特点包括: