JAX-COSMO: une librairie auto-differentiable et executable sur GPU pour la Cosmology.

Extrait de JAX-COSMO: An End-to-End Differentiable and GPU Accelerated Cosmology Library (https://arxiv.org/abs/2302.05163)

Nous présentons jax-cosmo, une bibliothèque pour les calculs en cosmologique automatiquement différentiables. jax-cosmo utilise la bibliothèque JAX, qui a créé un nouvel écosystème, en particulier dans le domaine de la programmation probabiliste. En plus de l’accélération par vectorisation, de la compilation juste à temps et de l’optimisation automatique du code pour différents matériels (CPU, GPU, TPU), JAX expose un mécanisme de différenciation automatique (autodiff). Grâce à autodiff, jax-cosmo donne accès aux dérivées des vraisemblances cosmologiques par rapport à n’importe lequel de leurs paramètres, et permet ainsi une gamme d’algorithmes puissants d’inférence bayésienne, autrement impraticables en cosmologie, tels que Hamiltonian Monte Carlo et l’inférence variationnelle. Dans sa version initiale, jax-cosmo implémente les spectres de puissance linéaires et non linéaires (en utilisant Halofit ou la fonction de transfert d’Eisenstein et Hu), ainsi que les spectres de puissance angulaire (Cl) avec l’approximation de Limber pour les sondes weak-lensing et number counts, tous différentiables par rapport aux paramètres cosmologiques et à leurs autres entrées. Nous illustrons comment la différenciation automatique peut changer la donne pour des tâches courantes impliquant des calculs de la matrice de Fisher ou une inférence postérieure complète avec des techniques basées sur le gradient (par exemple, Hamiltonian Monte Carlo). En particulier, nous montrons que les matrices de Fisher sont désormais rapides et exactes, qu’elles ne nécessitent plus de réglage fin et qu’elles sont elles-mêmes différentiables par rapport aux paramètres de la vraisemblance, ce qui permet d’optimiser des enquêtes complexes par une simple descente de gradient. Enfin, en utilisant une analyse 3x2pt du Dark Energy Survey Year 1 comme référence, nous démontrons comment jax-cosmo peut être combiné avec des langages de programmation probabilistes tels que NumPyro pour effectuer une inférence a posteriori avec des algorithmes tels que l’échantillonneur (No-U-Turns), une inférence variationnelle par différenciation automatique (ADVI) et un transport neuronal HMC (NeuTra). Nous montrons que la taille effective de l’échantillon par nœud (1 GPU ou 32 CPU) par heure de temps de mur est environ 5 fois meilleure pour un échantillonneur JAX NUTS par rapport à l’échantillonneur bien optimisé Cobaya Metropolis-Hasting. Nous démontrons en outre que la normalisation des flux à l’aide du transport neuronal est une méthodologie prometteuse pour la validation des modèles aux premiers stades de l’analyse.