@qjia7 qjia7 commented Nov 26, 2025

This pull request introduces a new mechanism for efficiently performing tensor data type casting on the WebGPU backend by dynamically generating minimal ONNX models for the Cast operator and caching corresponding inference sessions. The changes include a new utility for building ONNX Cast models, a session cache for reusing Cast operator sessions, and an implementation of a fast Cast method in the WebGPU device interface.

ONNX Cast model generation and caching:

  • Added cast_model_builder.cpp and cast_model_builder.h to manually construct minimal ONNX models for the Cast operator, enabling runtime generation of models for arbitrary input/output tensor types without depending on the ONNX library. [1] [2]
  • Implemented a thread-safe CastSessionCache in interface.cpp that stores and reuses inference sessions for each unique input/output type pair, reducing model creation and session initialization overhead.
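The cache described above can be sketched as follows. `Session` and `CreateCastSession` are hypothetical stand-ins for the real `OrtSession` and model-building code, and plain `int` type keys stand in for `ONNXTensorElementDataType`; this is an illustration of the caching pattern, not the PR's actual implementation:

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <utility>

struct Session { int in_type; int out_type; };  // stand-in for OrtSession

static std::unique_ptr<Session> CreateCastSession(int in_type, int out_type) {
  // In the real code this would build the Cast model bytes and create an
  // ONNX Runtime inference session on the WebGPU execution provider.
  return std::make_unique<Session>(Session{in_type, out_type});
}

class CastSessionCache {
 public:
  // Returns the cached session for this (input, output) type pair,
  // creating it under the lock on first use.
  Session* GetOrCreate(int input_type, int output_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto key = std::make_pair(input_type, output_type);
    auto it = sessions_.find(key);
    if (it == sessions_.end()) {
      it = sessions_.emplace(key, CreateCastSession(input_type, output_type)).first;
    }
    return it->second.get();
  }

 private:
  std::mutex mutex_;
  std::map<std::pair<int, int>, std::unique_ptr<Session>> sessions_;
};
```

Repeated calls with the same type pair return the same session, so model creation and session initialization happen at most once per pair.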

WebGPU device interface enhancements:

  • Added a new Cast method to the InterfaceImpl class in interface.cpp that uses the cached ONNX Cast sessions and ONNX Runtime's IOBinding to efficiently perform type conversion on tensors using the WebGPU execution provider.
  • Included a helper function in InterfaceImpl to determine the element size for each ONNX tensor data type, ensuring correct memory allocation during casting.
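The element-size helper mentioned above might take roughly this shape. The enum here is a local stand-in for a few `ONNXTensorElementDataType` values (the real enum lives in onnxruntime_c_api.h), so this is a sketch of the mapping, not the PR's code:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Local stand-ins for a few ONNX tensor element types.
enum TensorElementType { kFloat, kFloat16, kInt32, kInt64, kUInt8, kBool };

// Maps an element type to its size in bytes, so the caller can
// allocate output buffers of element_count * ElementSize(type) bytes.
static size_t ElementSize(TensorElementType type) {
  switch (type) {
    case kFloat:   return sizeof(float);     // 4
    case kFloat16: return sizeof(uint16_t);  // 2, fp16 stored as 16 bits
    case kInt32:   return sizeof(int32_t);   // 4
    case kInt64:   return sizeof(int64_t);   // 8
    case kUInt8:   return sizeof(uint8_t);   // 1
    case kBool:    return sizeof(uint8_t);   // 1
    default: throw std::runtime_error("Unsupported element type for Cast");
  }
}
```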

Copilot AI left a comment


Pull request overview

This PR implements efficient Cast operator support for the WebGPU backend by dynamically generating minimal ONNX models and caching inference sessions. The implementation enables type conversion operations to be performed on WebGPU devices without requiring external ONNX library dependencies.

  • Adds manual protobuf-based ONNX model generation for Cast operations
  • Implements thread-safe session caching to avoid redundant model creation
  • Provides WebGPU-specific Cast method using ONNX Runtime's IOBinding

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
src/webgpu/cast_model_builder.h: Declares the function to create ONNX Cast model bytes from input/output types
src/webgpu/cast_model_builder.cpp: Implements manual protobuf encoding to generate minimal ONNX Cast operator models without an ONNX library dependency
src/webgpu/interface.cpp: Adds CastSessionCache for thread-safe session reuse and implements the Cast method with an element-size helper for tensor creation
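The "manual protobuf encoding" used by the model builder can be illustrated with two small helpers. `WriteVarint` and `WriteLengthDelimited` are hypothetical names, but the wire format they emit (base-128 varints, wire type 2 for length-delimited fields) is the standard protobuf encoding that an ONNX ModelProto uses:

```cpp
#include <cstdint>
#include <vector>

// Appends a base-128 varint: 7 bits per byte, high bit set on all but
// the last byte.
static void WriteVarint(std::vector<uint8_t>& buffer, uint64_t value) {
  while (value >= 0x80) {
    buffer.push_back(static_cast<uint8_t>(value) | 0x80);
    value >>= 7;
  }
  buffer.push_back(static_cast<uint8_t>(value));
}

// Appends a length-delimited field (wire type 2): tag, payload length,
// then the payload bytes. Nested messages compose by encoding the inner
// message first and embedding it here.
static void WriteLengthDelimited(std::vector<uint8_t>& buffer,
                                 uint32_t field_number,
                                 const std::vector<uint8_t>& message) {
  WriteVarint(buffer, (static_cast<uint64_t>(field_number) << 3) | 2);
  WriteVarint(buffer, message.size());
  buffer.insert(buffer.end(), message.begin(), message.end());
}
```

With only these two primitives (plus a fixed-width variant for floats), a minimal ModelProto containing a single Cast node can be assembled byte by byte, which is why no protobuf or ONNX library dependency is needed.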

Comment on lines 227 to 236
bool Cast(void* input, void* output, ONNXTensorElementDataType input_type,
          ONNXTensorElementDataType output_type, size_t element_count) override {
  if (!ort_allocator_) {
    throw std::runtime_error("WebGPU allocator not initialized");
  }

  // Get or create cached session for this type pair
  OrtSession* session = g_cast_session_cache.GetOrCreate(input_type, output_type);
  if (!session) {
    return false;
  }

Copilot AI Nov 26, 2025


The Cast implementation is missing a check for when input and output types are the same. Both the CPU and CUDA implementations throw an error in this case:

if (input_type == output_type)
  throw std::runtime_error("Cast - input and output types are the same");

This check should be added at the beginning of the Cast method to maintain consistency with other device implementations and prevent unnecessary session creation and inference for no-op casts.
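In isolation, the suggested guard might look like this. `ElementType` and the trimmed-down `Cast` signature are stand-ins (the real method also takes the input/output buffers and element count), so this only sketches where the early check would sit:

```cpp
#include <stdexcept>

// Stand-in for ONNXTensorElementDataType from onnxruntime_c_api.h.
using ElementType = int;

// Guard at the top of Cast, mirroring the CPU and CUDA implementations;
// the session lookup and inference that follow are elided.
static bool Cast(ElementType input_type, ElementType output_type) {
  if (input_type == output_type)
    throw std::runtime_error("Cast - input and output types are the same");
  // ... get or create the cached session and run it via IOBinding ...
  return true;
}
```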

@qjia7 qjia7 marked this pull request as ready for review November 27, 2025 02:22

qjia7 commented Nov 27, 2025

@kunal-vaishnavi @fs-eire @guschmue In the latest commit, I moved the cached cast_sessions_ from the InterfaceImpl level to OrtGlobals. I found that InterfaceImpl is released after env_, which clears the WebGPU context, so releasing the cast sessions afterwards triggered a "context is not found" error.
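The destruction-order reasoning here can be demonstrated in miniature: C++ destroys struct members in reverse declaration order, so a cache declared after the environment inside a globals struct is guaranteed to be torn down first, while the environment (and the WebGPU context it owns) still exists. `Env`, `CastSessionCache`, and `OrtGlobals` below are toy stand-ins for the real types:

```cpp
#include <string>
#include <vector>

// Records destruction order so it can be inspected.
static std::vector<std::string> g_destroyed;

struct Env { ~Env() { g_destroyed.push_back("env"); } };
struct CastSessionCache { ~CastSessionCache() { g_destroyed.push_back("cast_sessions"); } };

// Members are destroyed in reverse declaration order: the session cache,
// declared after env_, is released first, while env_ is still alive.
struct OrtGlobals {
  Env env_;
  CastSessionCache cast_sessions_;
};
```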

buffer.insert(buffer.end(), message.begin(), message.end());
}

std::vector<uint8_t> CreateCastModelBytes(ONNXTensorElementDataType input_type, ONNXTensorElementDataType output_type) {
Contributor

Can we generalize this process for future 1-op graphs? We have considered implementing this idea to perform computations for some search and sampling algorithms with on-device memory.

Contributor Author

Could you clarify what you mean by ‘generalize this process’? What specific changes would you like me to make?

@qjia7 qjia7 marked this pull request as draft December 4, 2025 02:05