Implement Cast for webgpu #1895
base: main
Conversation
Pull request overview
This PR implements efficient Cast operator support for the WebGPU backend by dynamically generating minimal ONNX models and caching inference sessions. The implementation enables type conversion operations to be performed on WebGPU devices without requiring external ONNX library dependencies.
- Adds manual protobuf-based ONNX model generation for Cast operations
- Implements thread-safe session caching to avoid redundant model creation
- Provides WebGPU-specific Cast method using ONNX Runtime's IOBinding
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/webgpu/cast_model_builder.h | Declares the function to create ONNX Cast model bytes from input/output types |
| src/webgpu/cast_model_builder.cpp | Implements manual protobuf encoding to generate minimal ONNX Cast operator models without an ONNX library dependency (see the encoding sketch below the table) |
| src/webgpu/interface.cpp | Adds CastSessionCache for thread-safe session reuse and implements Cast method with element size helper for tensor creation |
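For readers unfamiliar with hand-rolled protobuf, here is a minimal sketch of the encoding technique `cast_model_builder.cpp` relies on. The helper names are illustrative rather than the PR's actual ones; the AttributeProto field numbers come from onnx.proto. Encoding the Cast node's `to` attribute might look like:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Wire types from the protobuf encoding spec.
constexpr uint8_t kVarint = 0;           // int32/int64/enum
constexpr uint8_t kLengthDelimited = 2;  // string/bytes/embedded message

// Append a base-128 varint: 7 bits per byte, MSB set on all but the last byte.
void AppendVarint(std::vector<uint8_t>& buffer, uint64_t value) {
  while (value >= 0x80) {
    buffer.push_back(static_cast<uint8_t>((value & 0x7F) | 0x80));
    value >>= 7;
  }
  buffer.push_back(static_cast<uint8_t>(value));
}

// Append a field tag: (field_number << 3) | wire_type.
void AppendTag(std::vector<uint8_t>& buffer, uint32_t field_number, uint8_t wire_type) {
  AppendVarint(buffer, (static_cast<uint64_t>(field_number) << 3) | wire_type);
}

// Append a length-delimited string field.
void AppendString(std::vector<uint8_t>& buffer, uint32_t field_number, const std::string& value) {
  AppendTag(buffer, field_number, kLengthDelimited);
  AppendVarint(buffer, value.size());
  buffer.insert(buffer.end(), value.begin(), value.end());
}

// Encode the Cast node's "to" attribute as an onnx.AttributeProto:
// name = field 1 (string), i = field 3 (int64), type = field 20 (enum, INT = 2).
std::vector<uint8_t> EncodeCastToAttribute(int64_t to_type) {
  std::vector<uint8_t> attr;
  AppendString(attr, 1, "to");
  AppendTag(attr, 3, kVarint);
  AppendVarint(attr, static_cast<uint64_t>(to_type));
  AppendTag(attr, 20, kVarint);
  AppendVarint(attr, 2);  // AttributeProto.AttributeType.INT
  return attr;
}
```

Every ONNX message bottoms out in the same two primitives (varints and length-delimited byte strings), so a purpose-built encoder for a single operator stays very small.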
```cpp
bool Cast(void* input, void* output, ONNXTensorElementDataType input_type, ONNXTensorElementDataType output_type, size_t element_count) override {
  if (!ort_allocator_) {
    throw std::runtime_error("WebGPU allocator not initialized");
  }

  // Get or create cached session for this type pair
  OrtSession* session = g_cast_session_cache.GetOrCreate(input_type, output_type);
  if (!session) {
    return false;
  }
```
Copilot AI commented on Nov 26, 2025
The Cast implementation is missing a check for when input and output types are the same. Both the CPU and CUDA implementations throw an error in this case:
```cpp
if (input_type == output_type)
  throw std::runtime_error("Cast - input and output types are the same");
```

This check should be added at the beginning of the Cast method to maintain consistency with other device implementations and prevent unnecessary session creation and inference for no-op casts.
@kunal-vaishnavi @fs-eire @guschmue In the latest commit, I moved the cached `cast_sessions_` from …
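For context, the cached sessions follow a get-or-create pattern keyed by the input/output type pair. A minimal sketch of such a mutex-guarded cache, assuming the PR's `CreateCastModelBytes` helper (the env/options parameters and member names here are illustrative, not the PR's actual signature):

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h"

// Declared in the PR's cast_model_builder.h.
std::vector<uint8_t> CreateCastModelBytes(ONNXTensorElementDataType input_type,
                                          ONNXTensorElementDataType output_type);

// Illustrative sketch of a mutex-guarded, per-type-pair session cache.
class CastSessionCache {
 public:
  Ort::Session* GetOrCreate(Ort::Env& env, const Ort::SessionOptions& options,
                            ONNXTensorElementDataType input_type,
                            ONNXTensorElementDataType output_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto key = std::make_pair(input_type, output_type);
    auto it = sessions_.find(key);
    if (it == sessions_.end()) {
      // Build the minimal Cast model once, then reuse the session thereafter.
      std::vector<uint8_t> model = CreateCastModelBytes(input_type, output_type);
      auto session = std::make_unique<Ort::Session>(env, model.data(), model.size(), options);
      it = sessions_.emplace(key, std::move(session)).first;
    }
    return it->second.get();
  }

 private:
  std::mutex mutex_;
  std::map<std::pair<ONNXTensorElementDataType, ONNXTensorElementDataType>,
           std::unique_ptr<Ort::Session>> sessions_;
};
```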
```cpp
  buffer.insert(buffer.end(), message.begin(), message.end());
}

std::vector<uint8_t> CreateCastModelBytes(ONNXTensorElementDataType input_type, ONNXTensorElementDataType output_type) {
```
Can we generalize this process for future 1-op graphs? We have considered implementing this idea to perform computations for some search and sampling algorithms with on-device memory.
Could you clarify what you mean by ‘generalize this process’? What specific changes would you like me to make?
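One possible reading of the suggestion, purely as an illustration (none of these names exist in the PR): factor the Cast-specific builder into a generic single-operator model builder.

```cpp
#include <cstdint>
#include <string>
#include <vector>
#include "onnxruntime_c_api.h"

// Hypothetical spec for a one-op graph; encoded_attributes would hold
// pre-encoded AttributeProto entries (e.g. Cast's "to", TopK's "axis").
struct SingleOpSpec {
  std::string op_type;  // e.g. "Cast", "TopK", "Softmax"
  std::vector<ONNXTensorElementDataType> input_types;
  std::vector<ONNXTensorElementDataType> output_types;
  std::vector<std::vector<uint8_t>> encoded_attributes;
};

// Hypothetical generic builder; CreateCastModelBytes would then become a thin wrapper.
std::vector<uint8_t> CreateSingleOpModelBytes(const SingleOpSpec& spec);
```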
This pull request introduces a new mechanism for efficiently performing tensor data type casting on the WebGPU backend by dynamically generating minimal ONNX models for the Cast operator and caching corresponding inference sessions. The changes include a new utility for building ONNX Cast models, a session cache for reusing Cast operator sessions, and an implementation of a fast Cast method in the WebGPU device interface.
ONNX Cast model generation and caching:
- Added `cast_model_builder.cpp` and `cast_model_builder.h` to manually construct minimal ONNX models for the Cast operator, enabling runtime generation of models for arbitrary input/output tensor types without depending on the ONNX library. [1] [2]
- Added a `CastSessionCache` in `interface.cpp` that stores and reuses inference sessions for each unique input/output type pair, reducing model creation and session initialization overhead.

WebGPU device interface enhancements:

- Added a `Cast` method to the `InterfaceImpl` class in `interface.cpp` that uses the cached ONNX Cast sessions and ONNX Runtime's IOBinding to efficiently perform type conversion on tensors using the WebGPU execution provider (a sketch follows this list).
- Added a helper to `InterfaceImpl` to determine the element size for each ONNX tensor data type, ensuring correct memory allocation during casting.
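A hedged sketch of how the bound-buffer Cast path might look, using the standard ONNX Runtime C++ API; the "WebGPU_Buffer" allocator name, the `X`/`Y` tensor names, and the `ElementSize` helper are assumptions, not confirmed by the diff:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include "onnxruntime_cxx_api.h"

// Hypothetical element-size helper, as described in the PR summary.
size_t ElementSize(ONNXTensorElementDataType type) {
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return 1;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: return 2;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: return 4;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: return 8;
    default: throw std::runtime_error("Cast - unsupported element type");
  }
}

// Wrap the caller's device buffers as OrtValues, bind them, and run the
// cached Cast session without copying data to the host.
void RunCast(Ort::Session& session, void* input, void* output,
             ONNXTensorElementDataType input_type,
             ONNXTensorElementDataType output_type,
             size_t element_count) {
  Ort::MemoryInfo memory_info("WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault);
  const int64_t shape[] = {static_cast<int64_t>(element_count)};

  Ort::Value in_tensor = Ort::Value::CreateTensor(
      memory_info, input, element_count * ElementSize(input_type), shape, 1, input_type);
  Ort::Value out_tensor = Ort::Value::CreateTensor(
      memory_info, output, element_count * ElementSize(output_type), shape, 1, output_type);

  Ort::IoBinding binding(session);
  binding.BindInput("X", in_tensor);
  binding.BindOutput("Y", out_tensor);
  session.Run(Ort::RunOptions{}, binding);
}
```

Binding both tensors up front lets the session write straight into the caller's output buffer, which is what makes this cast cheap relative to a host round-trip.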