One of the easiest approaches is torch.compile, the latest iteration of the PyTorch compiler (previous methods were TorchScript and FX tracing).
You simply write:
model = torch.compile(model)
"Across these 163 open-source models torch.compile works 93% of time, and the model runs 43% faster in training on an NVIDIA A100 GPU. At Float32 precision, it runs 21% faster on average and at AMP Precision it runs 51% faster on average."[1]
What Google is trying to do is involve more people in the R&D of these kinds of methods.
The near-term promise is that you can use AMD, CUDA, TPUs, CPUs, etc. without explicit vendor support for the framework on which the model was developed.
Disclaimer: I will be very handwavy; reality is complex.
This is achieved by compiling the graph into some intermediate representation and then implementing the right backend for that IR. For projects in this space, look at StableHLO, IREE, and OpenXLA.
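As a rough illustration of the "capture the graph as an IR" step, PyTorch's torch.export (PyTorch 2.1+) captures a model into a portable graph form. A sketch assuming that API, with a toy model of my own:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    example_inputs = (torch.randn(2, 8),)

    # Capture the model into a graph-based intermediate representation.
    exported = torch.export.export(model, example_inputs)

    # Printing the captured FX graph shows the IR that a backend
    # (Inductor, an XLA/StableHLO converter, etc.) would consume.
    print(exported.graph_module.graph)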
You can argue that JAX's jit is a form of such a compiler: it maps the traced operations down to XLA, which then does its own bit of magic to make them work on your backend.
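You can actually watch this happen in JAX: jitted functions expose a .lower() step whose output (in recent JAX versions) is StableHLO text, before XLA compiles it for the backend. A small sketch, where the function f is my own toy example:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sum(jnp.tanh(x) ** 2)

    x = jnp.ones((4, 4))

    # Trace and lower without executing: the result is the intermediate
    # representation (StableHLO/MLIR in recent JAX) handed to XLA.
    lowered = jax.jit(f).lower(x)
    print(lowered.as_text())

    # XLA then compiles that IR into a backend-specific executable.
    compiled = lowered.compile()
    print(compiled(x))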
It's transformations and abstractions all the way down.
What's the actual state of these "ML compilers" currently, and what is the near-term promise?