This uses the gemma-llm Python library, which is built on JAX: https://gemma-llm.readthedocs.io/en/latest/colab_finetuning....