JAX是一个TensorFlow的简化库,它结合了Autograd和XLA,专门用于高性能机器学习研究。
凭借Autograd,JAX可以求导循环、分支、递归和闭包函数,并且它可以进行三阶求导。通过grad,它支持自动模式反向求导(反向传播)和正向求导,且二者可以任何顺序任意组合。
得力于XLA,可以在GPU和TPU上编译和运行NumPy程序。默认情况下,编译发生在底层,库调用实时编译和执行。但是JAX还允许使用单一函数 APIjit将Python函数及时编译为XLA优化的内核。编译和自动求导可以任意组合,因此可以在Python环境下实现复杂的算法并获得最大的性能。
import jax.numpy as npfrom jax import grad, jit, vmapfrom functools import partialdef predict(params, inputs): for W, b in params: outputs = np.dot(inputs, W) + b inputs = np.tanh(outputs) return outputsdef logprob_fun(params, inputs, targets): preds = predict(params, inputs) return np.sum((preds - targets)**2)grad_fun = jit(grad(logprob_fun)) # compiled gradient evaluation functionperex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grads更深入地看,JAX实际上是一个可扩展的可组合函数转换系统,grad和jit都是这种转换的实例。
评论