XLA Kernel Fusion / Selection Expectations #23630
Labels
question
Further information is requested
stat:awaiting response from contributor
Awaiting response from contributor/author
Hi team,
I am writing to ask about the expected behavior of XLA kernel selection / fusion. My interest primarily lies in XLA<>GPU, but I am also eager to learn about the XLA<>TPU kernel selection process.
What I did was start with a simple torch model and use torch_xla and torch.autocast to export it to StableHLO with fp16 operations (saving the model weights as numpy arrays, since they are treated as model inputs). Then I used XLA AOT compilation with a pre-processing step in which I froze all the model weights as constants. I have some questions about the compiled kernels.
From what I can see (see the fusion report below), the convolutions are consistently fused as `cudnn-conv-bias-activation`. However, for the linear layers, some are fused as `gemm_fusion_dot` while others become `fused_reduce`. Taking a closer look at the `fused_reduce`, the `reduce` op after the matrix multiplication actually happens in FP32, which inserts two casts (`convert`s) before and after the reduce and slows down inference.
module_0001.IrToHlo.106.sm_8.6_gpu_after_optimizations-memory-usage-report.txt
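For reference, this is roughly the setup I use to generate the post-optimization dumps (a minimal sketch; the flag names come from recent XLA builds and may vary, and the dump directory is a placeholder):

```python
import os

# Set XLA debug flags before torch / torch_xla initialize the runtime.
# --xla_dump_to writes per-module HLO plus the *after_optimizations*
# memory-usage report; --xla_dump_hlo_pass_re=.* additionally dumps the
# HLO after each pass, which helps see where fused_reduce is introduced.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_dump_to=/tmp/xla_dump",   # placeholder output directory
    "--xla_dump_hlo_as_text",
    "--xla_dump_hlo_pass_re=.*",
])

import torch            # noqa: F401  (import after the flags are set)
import torch_xla        # noqa: F401
```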
Therefore, my questions are: what is the mechanism by which XLA selects which kernel to use after compilation? How can we modify this behavior so that the better kernel is selected? Is there any documentation describing the fusion / kernel-selection logic in the different XLA optimization passes that I can read?
Thank you in advance.
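To make the second question concrete, below is the kind of knob I am imagining. The flags are taken from XLA's GPU debug options and are only a hedged guess at the relevant levers, not something I know to be the intended mechanism:

```python
import os

# Illustrative only: GPU debug options that look related to GEMM kernel
# selection. I do not know whether these are the intended way to steer
# which fusion/kernel gets picked for the linear layers.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_autotune_level=4",          # autotuning aggressiveness (0-4)
    "--xla_gpu_enable_triton_gemm=false",  # fall back to cuBLAS GEMM calls
])
```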
Attached here is the torch -> torch_xla -> HLO export flow.
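In case the attachment does not render inline, here is a condensed sketch of that flow. The model is a stand-in, and the API names (`torch.export.export`, `torch_xla.stablehlo.exported_program_to_stablehlo`) reflect recent torch_xla releases, so treat the details as approximate rather than my exact script:

```python
import torch
import torch.nn as nn
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

# Placeholder model standing in for the real one (conv + linear layers).
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 32 * 32, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(x.flatten(1))

model = TinyModel().eval()
sample = torch.randn(1, 3, 32, 32)

# Trace under autocast so matmuls/convs are recorded in fp16; the exact
# autocast device/placement in my real script may differ.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    exported = export(model, (sample,))

# Lower the ExportedProgram to StableHLO; the weights show up as program
# inputs, which is why I later freeze them as constants before AOT compile.
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text("forward"))
```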