Uživatelské nástroje

Nástroje pro tento web


jax

= JAX – Moderní nástroj pro diferenciální programování a výpočetní akceleraci =

TOC

== Co je JAX? == JAX je open‑source knihovna od Google, která kombinuje NumPy‑like API s automatickým diferenciálním výpočtem a kompilací na různé hardwarové akcelerátory (CPU, GPU, TPU). Díky funkci jit (just‑in‑time) a transformacím jako grad, vmap, pmap a jit umožňuje psát čistý, funkcionální kód a nechat JAX optimalizovat a distribuovat výpočty.

== Historie a vývoj == | ^ Rok ^ | ^ Událost ^ | | 2018 | První veřejná verze JAX (v0.1) – zaměřená na automatické diferencování pomocí Autograd. | | 2019 | Přidání XLA (Accelerated Linear Algebra) backendu, podpora GPU a TPU. | | 2020 | Vydání v0.2 – zavedení jit, vmap a pmap. | | 2021 | JAX 1.0 – stabilní API, rozšířená podpora pro float8, bfloat16, a mixed‑precision. | | 2022 | Integrace s Flax, Haiku a Optax – ekosystém pro modelování, trénink a optimalizaci. | | 2023 | JAX 2.0 – podpora GPU‑TensorCore a TPU‑v5 instrukcí, rozšířené pjit a sharding. | | 2024 | Rozšířený JAX 3.0 preview – podpora distributed arrays (GDA) a sparse operací. |

== Základní koncepty == === 1. Funkcionální programování === JAX vyžaduje, aby funkce byly čisté (bez vedlejších efektů). To umožňuje:

  Transformace (grad, jit, vmap, pmap, pmap, pjit) – mohou být aplikovány na libovolnou funkci.
  Deterministické chování – důležité pro reprodukovatelnost experimentů.

2. Autodiff (automatické diferenciace)

  grad(f) – vrací funkci, která počítá gradient skalárního výstupu.
  value_and_grad(f) – vrací jak hodnotu, tak gradient.
  jacfwd, jacrev – Jacobian pomocí forward‑ nebo reverse‑mode.

=== 3. Just‑In‑Time kompilace (JIT) === jit(f) kompiluje funkci pomocí XLA a uloží optimalizovaný kód pro daný hardware. Při opakovaném volání je výkon až 10‑30× rychlejší než čistý NumPy.

=== 4. Vektorizace (vmap) === vmap(f) automaticky mapuje funkci přes batche, čímž eliminuje potřebu explicitních for‑loopů a umožňuje SIMD‑styl výpočty.

5. Paralelní a distribuovaná výpočty (pmap, pj​it)

  pmap(f) – paralelní mapování přes více zařízení (GPU/TPU) na jednom hostu.
  pjit(f, in_sharding, out_sharding) – explicitní sharding pro distribuované výpočty napříč clusterem.

== Instalace ==

Instalace z PyPI (CPU‑only)
 
pip install --upgrade "jax[cpu]"
Instalace s GPU (CUDA 12)
 
pip install --upgrade "jax[cuda12_cudnn89]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Instalace s TPU (v Cloud Shell)
 
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 

== Jednoduchý příklad – lineární regrese ==

 import jax.numpy as jnp from jax import grad, jit, random
Generování syntetických dat
 
key = random.PRNGKey(0) X = random.normal(key, (1000, 3)) true_w = jnp.array([1.5, -2.0, 0.7]) y = X @ true_w + 0.1 * random.normal(key, (1000,))
Definice ztrátové funkce
 
def loss(w, X, y): preds = X @ w return jnp.mean((preds - y) ** 2)
Gradient a JIT‑ovaná verze
 
grad_loss = jit(grad(loss))
Optimalizace (SGD)
 
def sgd_step(w, lr, X, y): g = grad_loss(w, X, y) return w - lr * g
Trénink
 
w = jnp.zeros(3) lr = 0.01 for epoch in range(500): w = sgd_step(w, lr, X, y)
 
print("Naučené váhy:", w) 

Co se zde děje?

  Používáme jax.numpy (drop‑in replacement za NumPy).
  jit kompiluje gradientní výpočet jednou a pak ho znovu používá.
  Všechny operace jsou vektorové a běží na GPU/TPU, pokud jsou k dispozici.

Pokročilé funkce

=== 1. Vmap – batched inference ===

 def model(params, x): w, b = params return jnp.dot(x, w) + b
 
batched_model = jax.vmap(model, in_axes=(None, 0))
 
params = (jnp.array([0.2, -0.5]), 0.1) xs = jnp.arange(10).reshape(5, 2) # 5 příkladů, 2 dimenze print(batched_model(params, xs)) 

=== 2. Pmap – data‑parallel trénink na 8 TPU ===

 from jax import pmap
 
def step(state, batch): grads = grad(loss)(state['params'], batch['X'], batch['y']) new_params = jax.tree_util.tree_map(lambda p, g: p - 0.001 * g, state['params'], grads) return {'params': new_params}, None
Rozdělení dat na 8 zařízení
 
def shard(data): return data.reshape(8, -1, *data.shape[1:])
 
state = {'params': jnp.zeros((3,))} batch = {'X': shard(X), 'y': shard(y)} state, _ = pmap(step, axis_name='devices')(state, batch) 

=== 3. Pjit a sharding – distribuované matice ===

 from jax.experimental import pjit from jax.sharding import Mesh, PartitionSpec as P
 
mesh = Mesh(jax.devices(), ('data',)) @pjit def matmul(A, B): return jnp.dot(A, B)
 
A = jnp.arange(40964096).reshape(4096, 4096) B = jnp.arange(40964096).reshape(4096, 4096)
Sharding specifikace
 
A_sharded = jax.device_put(A, jax.sharding.NamedSharding(mesh, P('data', None))) B_sharded = jax.device_put(B, jax.sharding.NamedSharding(mesh, P(None, 'data')))
 
C = matmul(A_sharded, B_sharded) 

JAX a TPU – ideální dvojice

  XLA backend: JAX automaticky překládá operace do XLA, což je optimalizovaný kompilátor používaný i v Google TPU.
  bfloat16 a float8: Podpora těchto formátů umožňuje vyšší throughput a nižší spotřebu paměti na TPU.
  Sparsity: V JAX 3.0 je experimentální podpora pro sparse matice, což je důležité pro velké jazykové modely (LLM).

Příklad – trénink Transformeru na TPU v4

 import flax.linen as nn import optax from flax.training import train_state
Jednoduchý Transformer blok (zkrácený)
 
class SimpleTransformer(nn.Module): d_model: int = 512 n_head: int = 8 n_layer: int = 6
 
@nn.compact
def __call__(self, x):
    for _ in range(self.n_layer):
        x = nn.SelfAttention(num_heads=self.n_head,
                             qkv_features=self.d_model)(x)
        x = nn.Dense(self.d_model)(x)
    return x
 
Inicializace
 
rng = jax.random.PRNGKey(0) model = SimpleTransformer() params = model.init(rng, jnp.ones((1, 128, 512)))['params']
Optimizer (AdamW)
 
tx = optax.adamw(1e-4) state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
JIT‑ovaná tréninková smyčka
 
@jax.jit def train_step(state, batch): def loss_fn(params): logits = state.apply_fn({'params': params}, batch['inputs']) loss = optax.softmax_cross_entropy(logits, batch['targets']).mean() return loss grads = jax.grad(loss_fn)(state.params) return state.apply_gradients(grads=grads)
Příklad batchu (předpokládá se, že je již shardován na TPU)
 
batch = {'inputs': jnp.ones((8, 128, 512)), # 8 devices 'targets': jnp.ones((8, 128, 512))} state = train_step(state, batch) print("Tréninková ztráta:", state.loss) 

Ekosystém kolem JAX

Knihovna Popis GitHub Flax Modulární NN knihovna (high‑level) https://github.com/google/flax Haiku DeepMind‑inspirovaná NN knihovna (object‑oriented) https://github.com/deepmind/dm-haiku Optax Sada optimalizačních algoritmů a schedulérů https://github.com/google-deepmind/optax Equinox „PyTorch‑like“ API s podporou JAX, zaměřená na jednoduchost https://github.com/patrick-kidger/equinox Chex Testovací a debugging nástroje pro JAX kód https://github.com/google/chex JAX‑MD Simulace molekulární dynamiky a fyzikální modely https://github.com/google/jax-md

Porovnání JAX vs. PyTorch vs. TensorFlow

Kritérium JAX PyTorch TensorFlow Výkon (XLA) Vysoký – kompilace do XLA (GPU/TPU) Dobrá – TorchScript, ale méně optimalizované pro TPU Dobrá – TF‑XLA, ale složitější API Autodiff Reverse‑mode, forward‑mode, vmap, pmap, grad Autograd, torch.autograd, torch.jit tf.GradientTape, tf.function Funkcionální styl Povinný (čisté funkce) – usnadňuje transformace Volitelný (imperativní) Volitelný (imperativní + graph) Distribuce pmap, pjit, mesh sharding – nativní podpora torch.distributed, torch.nn.parallel tf.distribute.Strategy Ekosystém Flax, Haiku, Optax, Equinox – rychle rostoucí torchvision, torchtext, Lightning Keras, tf.keras, TFLite Kompatibilita s TPU Native (XLA backend) – nejlepší volba Experimentální (via torch_xla) Native (TF‑XLA) Learning curve Střední – nutnost pochopit funkcionální paradigm Nízká – podobné NumPy/Python Střední – graph vs. eager

Tipy a best practices

  Používejte jit – i pro malé funkce, aby se XLA optimalizoval.
  Vektorizujte s vmap – nahraďte for‑loops, získáte SIMD‑výhody.
  Rozdělujte data s pmap – pokud máte více zařízení, trénujte paralelně.
  Explicitní typy – při práci s TPU specifikujte jnp.bfloat16 nebo jnp.float8 pro vyšší propustnost.
  Kontrola paměti – jax.debug.print a jax.profiler pomáhají najít memory leaks.
  Reproducibilita – vždy inicializujte PRNGKey a předávejte jej explicitně.
  Kompatibilita s numpy – jax.numpy je drop‑in, ale některé funkce (např. np.linalg.eig) nejsou podporovány – použijte ekvivalenty z jax.scipy.

Budoucnost JAX

  JAX 3.0+ – podpora global device arrays (GDA) a asynchronní sharding.
  Sparse & Structured Matrices – rozšířená podpora pro sparse operace, klíčové pro LLM.
  Compiler‑level optimizations – další vylepšení XLA (např. fusion, tiling) a podpora float8 na GPU/TPU.
  Interoperabilita s PyTorch – projekty jako torch2jax a jax2torch usnadní migraci kódu.
  Rozšířený ekosystém – nové knihovny pro probabilistické programování (např. NumPyro) a diferenciální fyziku (např. Diffrax).

== Závěr == JAX představuje moderní, výkonný a flexibilní nástroj pro výzkumníky i inženýry, kteří potřebují rychlé diferenciální výpočty a škálovatelnost napříč CPU, GPU i TPU. Díky čistému funkcionálnímu přístupu, silnému ekosystému (Flax, Optax, Haiku) a nativní podpoře pro XLA je JAX dnes jedním z hlavních pilířů vývoje velkých modelů (transformery, diffusion, reinforcement learning). Pokud chcete maximalizovat výkon na TPU, JAX + Flax je momentálně nejefektivnější cesta.

Odkazy a literatura
  [[https://github.com/google/jax|JAX – GitHub repository]]
  [[https://jax.readthedocs.io|JAX Documentation]]
  [[https://flax.readthedocs.io|Flax – Neural network library]]
  [[https://optax.readthedocs.io|Optax – Gradient processing & optimizers]]
  [[https://github.com/google/chex|Chex – Testing utilities for JAX]]
  [[https://arxiv.org/abs/1910.01408|“JAX: composable transformations of Python+NumPy programs” – arXiv 2019]]
  [[https://cloud.google.com/tpu/docs/jax-quickstart|Google Cloud – JAX on TPU Quickstart]]
  [[https://github.com/deepmind/dm-haiku|Haiku – DeepMind’s neural network library]]
  [[https://github.com/google/jax-md|JAX‑MD – Molecular dynamics library]]
jax.txt · Poslední úprava: autor: admin