Optimized NN Inference using custom Triton kernels

Implemented a high-performance linear layer (forward and backward passes) with optional fused activation using OpenAI's Triton.

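Below is a minimal, illustrative sketch (not the project's actual kernel) of how a Triton GEMM with a fused activation epilogue can look. The block sizes, the choice of ReLU, fp32 accumulation, the omission of bias, and the `fused_linear` wrapper name are all assumptions made for the example.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def linear_act_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    APPLY_RELU: tl.constexpr,
):
    # Each program computes one BLOCK_M x BLOCK_N tile of C = A @ B.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        a_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k) < K)
        b_mask = ((offs_k[:, None] + k) < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    if APPLY_RELU:
        # The activation runs in the epilogue, so no separate kernel launch
        # or extra round trip through global memory is needed.
        acc = tl.maximum(acc, 0.0)

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)


def fused_linear(x: torch.Tensor, weight: torch.Tensor, apply_relu: bool = True) -> torch.Tensor:
    # x: (M, K) fp32, weight: (N, K) as stored by nn.Linear; computes relu(x @ weight.T).
    M, K = x.shape
    N = weight.shape[0]
    w_t = weight.t().contiguous()
    out = torch.empty((M, N), device=x.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
    linear_act_kernel[grid](
        x, w_t, out, M, N, K,
        x.stride(0), x.stride(1),
        w_t.stride(0), w_t.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
        APPLY_RELU=apply_relu,
    )
    return out
```
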
  • The custom Triton-based linear layer delivered up to a 1.6x speedup when training FlanT5-Base on the SAMSum dataset and up to a 3.5x speedup in inference.
  • Automated the patching of PyTorch's torch.nn.Linear and adjacent activation layers with the custom fused layers for inference, using torch.fx for pattern matching and CUDA Graphs to reduce launch overhead (a sketch follows below).
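
A hedged sketch of that patching flow, assuming a Linear → ReLU pattern; `FusedLinearReLU`, `fuse_linear_relu`, and `capture_with_cuda_graph` are illustrative names, not the project's actual API:

```python
import torch
import torch.nn as nn
import torch.fx as fx


class FusedLinearReLU(nn.Module):
    """Stand-in for the Triton-backed fused layer (illustrative only)."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        # A real implementation would call the fused Triton kernel here.
        return torch.relu(self.linear(x))


def fuse_linear_relu(model: nn.Module) -> fx.GraphModule:
    """Rewrite Linear -> relu node pairs into a single fused module."""
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        if node.op != "call_module" or not isinstance(modules.get(node.target), nn.Linear):
            continue
        users = list(node.users)
        if len(users) != 1:
            continue
        relu = users[0]
        if not (relu.op == "call_function" and relu.target in (torch.relu, torch.nn.functional.relu)):
            continue
        fused_name = node.target.replace(".", "_") + "_fused"
        gm.add_submodule(fused_name, FusedLinearReLU(modules[node.target]))
        with gm.graph.inserting_after(node):
            fused = gm.graph.call_module(fused_name, node.args, node.kwargs)
        relu.replace_all_uses_with(fused)
        gm.graph.erase_node(relu)
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


def capture_with_cuda_graph(model: nn.Module, example: torch.Tensor):
    """Capture a static-shape forward pass into a CUDA Graph for cheap replays."""
    static_in = example.clone()
    # Warm up on a side stream before capture, per the CUDA Graphs recipe.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_in)

    def run(x: torch.Tensor) -> torch.Tensor:
        static_in.copy_(x)   # copy new inputs into the captured buffer
        graph.replay()       # replay the recorded kernels
        return static_out

    return run
```

Because the captured graph keeps its input and output tensors at fixed addresses, new inputs are copied into the captured buffer rather than passed directly; the replay then skips per-call Python and kernel-launch overhead at inference time.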