= JAX – Moderní nástroj pro diferenciální programování a výpočetní akceleraci =
TOC
== Co je JAX? == JAX je open‑source knihovna od Google, která kombinuje NumPy‑like API s automatickým diferenciálním výpočtem a kompilací na různé hardwarové akcelerátory (CPU, GPU, TPU). Díky funkci jit (just‑in‑time) a transformacím jako grad, vmap, pmap a jit umožňuje psát čistý, funkcionální kód a nechat JAX optimalizovat a distribuovat výpočty.
== Historie a vývoj == | ^ Rok ^ | ^ Událost ^ | | 2018 | První veřejná verze JAX (v0.1) – zaměřená na automatické diferencování pomocí Autograd. | | 2019 | Přidání XLA (Accelerated Linear Algebra) backendu, podpora GPU a TPU. | | 2020 | Vydání v0.2 – zavedení jit, vmap a pmap. | | 2021 | JAX 1.0 – stabilní API, rozšířená podpora pro float8, bfloat16, a mixed‑precision. | | 2022 | Integrace s Flax, Haiku a Optax – ekosystém pro modelování, trénink a optimalizaci. | | 2023 | JAX 2.0 – podpora GPU‑TensorCore a TPU‑v5 instrukcí, rozšířené pjit a sharding. | | 2024 | Rozšířený JAX 3.0 preview – podpora distributed arrays (GDA) a sparse operací. |
== Základní koncepty == === 1. Funkcionální programování === JAX vyžaduje, aby funkce byly čisté (bez vedlejších efektů). To umožňuje:
Transformace (grad, jit, vmap, pmap, pmap, pjit) – mohou být aplikovány na libovolnou funkci. Deterministické chování – důležité pro reprodukovatelnost experimentů.
2. Autodiff (automatické diferenciace)
grad(f) – vrací funkci, která počítá gradient skalárního výstupu. value_and_grad(f) – vrací jak hodnotu, tak gradient. jacfwd, jacrev – Jacobian pomocí forward‑ nebo reverse‑mode.
=== 3. Just‑In‑Time kompilace (JIT) === jit(f) kompiluje funkci pomocí XLA a uloží optimalizovaný kód pro daný hardware. Při opakovaném volání je výkon až 10‑30× rychlejší než čistý NumPy.
=== 4. Vektorizace (vmap) === vmap(f) automaticky mapuje funkci přes batche, čímž eliminuje potřebu explicitních for‑loopů a umožňuje SIMD‑styl výpočty.
5. Paralelní a distribuovaná výpočty (pmap, pjit)
pmap(f) – paralelní mapování přes více zařízení (GPU/TPU) na jednom hostu. pjit(f, in_sharding, out_sharding) – explicitní sharding pro distribuované výpočty napříč clusterem.
== Instalace ==
Instalace z PyPI (CPU‑only) pip install --upgrade "jax[cpu]" Instalace s GPU (CUDA 12) pip install --upgrade "jax[cuda12_cudnn89]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html Instalace s TPU (v Cloud Shell) pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
== Jednoduchý příklad – lineární regrese ==
import jax.numpy as jnp from jax import grad, jit, random Generování syntetických dat key = random.PRNGKey(0) X = random.normal(key, (1000, 3)) true_w = jnp.array([1.5, -2.0, 0.7]) y = X @ true_w + 0.1 * random.normal(key, (1000,)) Definice ztrátové funkce def loss(w, X, y): preds = X @ w return jnp.mean((preds - y) ** 2) Gradient a JIT‑ovaná verze grad_loss = jit(grad(loss)) Optimalizace (SGD) def sgd_step(w, lr, X, y): g = grad_loss(w, X, y) return w - lr * g Trénink w = jnp.zeros(3) lr = 0.01 for epoch in range(500): w = sgd_step(w, lr, X, y) print("Naučené váhy:", w)
Co se zde děje?
Používáme jax.numpy (drop‑in replacement za NumPy). jit kompiluje gradientní výpočet jednou a pak ho znovu používá. Všechny operace jsou vektorové a běží na GPU/TPU, pokud jsou k dispozici.
Pokročilé funkce
=== 1. Vmap – batched inference ===
def model(params, x): w, b = params return jnp.dot(x, w) + b batched_model = jax.vmap(model, in_axes=(None, 0)) params = (jnp.array([0.2, -0.5]), 0.1) xs = jnp.arange(10).reshape(5, 2) # 5 příkladů, 2 dimenze print(batched_model(params, xs))
=== 2. Pmap – data‑parallel trénink na 8 TPU ===
from jax import pmap def step(state, batch): grads = grad(loss)(state['params'], batch['X'], batch['y']) new_params = jax.tree_util.tree_map(lambda p, g: p - 0.001 * g, state['params'], grads) return {'params': new_params}, None Rozdělení dat na 8 zařízení def shard(data): return data.reshape(8, -1, *data.shape[1:]) state = {'params': jnp.zeros((3,))} batch = {'X': shard(X), 'y': shard(y)} state, _ = pmap(step, axis_name='devices')(state, batch)
=== 3. Pjit a sharding – distribuované matice ===
from jax.experimental import pjit from jax.sharding import Mesh, PartitionSpec as P mesh = Mesh(jax.devices(), ('data',)) @pjit def matmul(A, B): return jnp.dot(A, B) A = jnp.arange(40964096).reshape(4096, 4096) B = jnp.arange(40964096).reshape(4096, 4096) Sharding specifikace A_sharded = jax.device_put(A, jax.sharding.NamedSharding(mesh, P('data', None))) B_sharded = jax.device_put(B, jax.sharding.NamedSharding(mesh, P(None, 'data'))) C = matmul(A_sharded, B_sharded)
JAX a TPU – ideální dvojice
XLA backend: JAX automaticky překládá operace do XLA, což je optimalizovaný kompilátor používaný i v Google TPU. bfloat16 a float8: Podpora těchto formátů umožňuje vyšší throughput a nižší spotřebu paměti na TPU. Sparsity: V JAX 3.0 je experimentální podpora pro sparse matice, což je důležité pro velké jazykové modely (LLM).
Příklad – trénink Transformeru na TPU v4
import flax.linen as nn import optax from flax.training import train_state Jednoduchý Transformer blok (zkrácený) class SimpleTransformer(nn.Module): d_model: int = 512 n_head: int = 8 n_layer: int = 6 @nn.compact def __call__(self, x): for _ in range(self.n_layer): x = nn.SelfAttention(num_heads=self.n_head, qkv_features=self.d_model)(x) x = nn.Dense(self.d_model)(x) return x Inicializace rng = jax.random.PRNGKey(0) model = SimpleTransformer() params = model.init(rng, jnp.ones((1, 128, 512)))['params'] Optimizer (AdamW) tx = optax.adamw(1e-4) state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) JIT‑ovaná tréninková smyčka @jax.jit def train_step(state, batch): def loss_fn(params): logits = state.apply_fn({'params': params}, batch['inputs']) loss = optax.softmax_cross_entropy(logits, batch['targets']).mean() return loss grads = jax.grad(loss_fn)(state.params) return state.apply_gradients(grads=grads) Příklad batchu (předpokládá se, že je již shardován na TPU) batch = {'inputs': jnp.ones((8, 128, 512)), # 8 devices 'targets': jnp.ones((8, 128, 512))} state = train_step(state, batch) print("Tréninková ztráta:", state.loss)
Ekosystém kolem JAX
| Knihovna | Popis | GitHub | Flax | Modulární NN knihovna (high‑level) | https://github.com/google/flax | Haiku | DeepMind‑inspirovaná NN knihovna (object‑oriented) | https://github.com/deepmind/dm-haiku | Optax | Sada optimalizačních algoritmů a schedulérů | https://github.com/google-deepmind/optax | Equinox | „PyTorch‑like“ API s podporou JAX, zaměřená na jednoduchost | https://github.com/patrick-kidger/equinox | Chex | Testovací a debugging nástroje pro JAX kód | https://github.com/google/chex | JAX‑MD | Simulace molekulární dynamiky a fyzikální modely | https://github.com/google/jax-md |
|---|
Porovnání JAX vs. PyTorch vs. TensorFlow
| Kritérium | JAX | PyTorch | TensorFlow | Výkon (XLA) | Vysoký – kompilace do XLA (GPU/TPU) | Dobrá – TorchScript, ale méně optimalizované pro TPU | Dobrá – TF‑XLA, ale složitější API | Autodiff | Reverse‑mode, forward‑mode, vmap, pmap, grad | Autograd, torch.autograd, torch.jit | tf.GradientTape, tf.function | Funkcionální styl | Povinný (čisté funkce) – usnadňuje transformace | Volitelný (imperativní) | Volitelný (imperativní + graph) | Distribuce | pmap, pjit, mesh sharding – nativní podpora | torch.distributed, torch.nn.parallel | tf.distribute.Strategy | Ekosystém | Flax, Haiku, Optax, Equinox – rychle rostoucí | torchvision, torchtext, Lightning | Keras, tf.keras, TFLite | Kompatibilita s TPU | Native (XLA backend) – nejlepší volba | Experimentální (via torch_xla) | Native (TF‑XLA) | Learning curve | Střední – nutnost pochopit funkcionální paradigm | Nízká – podobné NumPy/Python | Střední – graph vs. eager |
|---|
Tipy a best practices
Používejte jit – i pro malé funkce, aby se XLA optimalizoval. Vektorizujte s vmap – nahraďte for‑loops, získáte SIMD‑výhody. Rozdělujte data s pmap – pokud máte více zařízení, trénujte paralelně. Explicitní typy – při práci s TPU specifikujte jnp.bfloat16 nebo jnp.float8 pro vyšší propustnost. Kontrola paměti – jax.debug.print a jax.profiler pomáhají najít memory leaks. Reproducibilita – vždy inicializujte PRNGKey a předávejte jej explicitně. Kompatibilita s numpy – jax.numpy je drop‑in, ale některé funkce (např. np.linalg.eig) nejsou podporovány – použijte ekvivalenty z jax.scipy.
Budoucnost JAX
JAX 3.0+ – podpora global device arrays (GDA) a asynchronní sharding. Sparse & Structured Matrices – rozšířená podpora pro sparse operace, klíčové pro LLM. Compiler‑level optimizations – další vylepšení XLA (např. fusion, tiling) a podpora float8 na GPU/TPU. Interoperabilita s PyTorch – projekty jako torch2jax a jax2torch usnadní migraci kódu. Rozšířený ekosystém – nové knihovny pro probabilistické programování (např. NumPyro) a diferenciální fyziku (např. Diffrax).
== Závěr == JAX představuje moderní, výkonný a flexibilní nástroj pro výzkumníky i inženýry, kteří potřebují rychlé diferenciální výpočty a škálovatelnost napříč CPU, GPU i TPU. Díky čistému funkcionálnímu přístupu, silnému ekosystému (Flax, Optax, Haiku) a nativní podpoře pro XLA je JAX dnes jedním z hlavních pilířů vývoje velkých modelů (transformery, diffusion, reinforcement learning). Pokud chcete maximalizovat výkon na TPU, JAX + Flax je momentálně nejefektivnější cesta.
Odkazy a literatura
[[https://github.com/google/jax|JAX – GitHub repository]] [[https://jax.readthedocs.io|JAX Documentation]] [[https://flax.readthedocs.io|Flax – Neural network library]] [[https://optax.readthedocs.io|Optax – Gradient processing & optimizers]] [[https://github.com/google/chex|Chex – Testing utilities for JAX]] [[https://arxiv.org/abs/1910.01408|“JAX: composable transformations of Python+NumPy programs” – arXiv 2019]] [[https://cloud.google.com/tpu/docs/jax-quickstart|Google Cloud – JAX on TPU Quickstart]] [[https://github.com/deepmind/dm-haiku|Haiku – DeepMind’s neural network library]] [[https://github.com/google/jax-md|JAX‑MD – Molecular dynamics library]]
