optimization

Optimized NN Inference Using Custom Triton Kernels

Implemented a high-performance linear layer, covering both the forward and backward pass, with optional activation fusion using OpenAI’s Triton. The custom Triton-based linear layer delivered up to a 1.6x speedup when training Flan-T5-Base on the SAMSum dataset and up to a 3.5x speedup in inference. Automated the replacement of PyTorch’s nn.Linear layers and their associated activation layers with the custom layers for inference. torch.fx handled the pattern matching, and CUDA Graphs reduced kernel-launch overhead.
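The following is a rough sketch of the torch.fx patching step. It traces a model, finds nn.Linear modules whose only consumer is an nn.ReLU or nn.GELU module, and replaces each pair with a single fused module. FusedLinearAct and fuse_linear_activation are hypothetical names. FusedLinearAct is a placeholder that computes the same result in eager PyTorch, where the project's version would launch the Triton kernel. The set of fusable activations is also an assumption. The sketch only matches activations used as modules, not functional calls such as F.relu, and it leaves out the CUDA Graphs step.

```python
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F


class FusedLinearAct(nn.Module):
    """Hypothetical stand-in for a Triton-backed linear + activation layer.

    A real version would launch one fused kernel. This sketch reuses the
    original weights and computes the same result in eager PyTorch.
    """

    def __init__(self, linear: nn.Linear, activation: str):
        super().__init__()
        self.weight = linear.weight  # shared Parameter, no copy
        self.bias = linear.bias      # may be None
        self.activation = activation

    def forward(self, x):
        y = F.linear(x, self.weight, self.bias)
        if self.activation == "relu":
            return F.relu(y)
        if self.activation == "gelu":
            return F.gelu(y)
        return y


# Activation modules this sketch knows how to fuse (an assumption).
_FUSABLE = {nn.ReLU: "relu", nn.GELU: "gelu"}


def fuse_linear_activation(model: nn.Module) -> fx.GraphModule:
    """Replace Linear -> activation module pairs with FusedLinearAct."""
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        if node.op != "call_module":
            continue
        kind = _FUSABLE.get(type(modules[node.target]))
        if kind is None:
            continue
        prev = node.args[0]
        # Only fuse when the activation's input is an nn.Linear call
        # whose output feeds nothing else.
        if not (
            isinstance(prev, fx.Node)
            and prev.op == "call_module"
            and isinstance(modules[prev.target], nn.Linear)
            and len(prev.users) == 1
        ):
            continue
        fused_name = prev.target.replace(".", "_") + "_fused"
        gm.add_submodule(fused_name, FusedLinearAct(modules[prev.target], kind))
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_module(fused_name, args=prev.args)
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)  # activation first, it uses prev
        gm.graph.erase_node(prev)
    gm.delete_all_unused_submodules()  # drop the now-unreferenced originals
    gm.graph.lint()
    gm.recompile()
    return gm
```

For example, applying fuse_linear_activation to nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4), nn.GELU()) produces a GraphModule with two fused calls. Because the fused modules share the original weights, the fused model gives the same outputs as the original.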