JAX TensorFlow 简化库开源项目

我要开发同款
匿名用户2018年12月14日
40阅读
开发技术Python
所属分类人工智能、机器学习/深度学习
授权协议Apache-2.0

作品详情

JAX是一个TensorFlow的简化库,它结合了Autograd和XLA,专门用于高性能机器学习研究。

凭借Autograd,JAX可以求导循环、分支、递归和闭包函数,并且它可以进行三阶求导。通过grad,它支持自动模式反向求导(反向传播)和正向求导,且二者可以任何顺序任意组合。

得力于XLA,可以在GPU和TPU上编译和运行NumPy程序。默认情况下,编译发生在底层,库调用实时编译和执行。但是JAX还允许使用单一函数 APIjit将Python函数及时编译为XLA优化的内核。编译和自动求导可以任意组合,因此可以在Python环境下实现复杂的算法并获得最大的性能。

import jax.numpy as npfrom jax import grad, jit, vmapfrom functools import partialdef predict(params, inputs):  for W, b in params:    outputs = np.dot(inputs, W) + b    inputs = np.tanh(outputs)  return outputsdef logprob_fun(params, inputs, targets):  preds = predict(params, inputs)  return np.sum((preds - targets)**2)grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation functionperex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads

更深入地看,JAX实际上是一个可扩展的可组合函数转换系统,grad和jit都是这种转换的实例。

声明:本文仅代表作者观点,不代表本站立场。如果侵犯到您的合法权益,请联系我们删除侵权资源!如果遇到资源链接失效,请您通过评论或工单的方式通知管理员。未经允许,不得转载,本站所有资源文章禁止商业使用运营!
下载安装【程序员客栈】APP
实时对接需求、及时收发消息、丰富的开放项目需求、随时随地查看项目状态

评论