
新智元报道
新智元报道
【新智元导读】JAX在最近的基准测试中的性能已经不声不响地超过了Pytorch和TensorFlow,也许未来会有更多的大模型诞生在这个平台上。谷歌在背后的默默付出终于得到了回报。




模型
最近,Keras团队为三个后端(TensorFlow、JAX、PyTorch)与原生PyTorch实现以及搭配TensorFlow的Keras 2进行了基准测试。 首先,他们为生成式和非生成式人工智能任务选择了一组主流的计算机视觉和自然语言处理模型:

对于模型的Keras版本,其采用了KerasCV和KerasNLP中已有的实现进行构建。而对于原生的PyTorch版本,则选择了网络上最流行的几个选项: - 来自HuggingFace Transformers的BERT、Gemma、Mistral - 来自HuggingFace Diffusers的StableDiffusion - 来自Meta的SegmentAnything 他们将这组模型称作「Native PyTorch」,以便与使用PyTorch后端的Keras 3版本进行区分。 他们对所有基准测试都使用了合成数据,并在所有LLM训练和推理中使用了bfloat16精度,同时在所有LLM训练中使用了LoRA(微调)。 根据PyTorch团队的建议,他们在原生PyTorch实现中使用了torch.compile(model, mode="reduce-overhead")(由于不兼容,Gemma和Mistral训练除外)。 为了衡量开箱即用的性能,他们使用高级API(例如HuggingFace的Trainer()、标准PyTorch训练循环和Keras model.fit()),并尽可能减少配置。 硬件配置

硬件配置
所有基准测试均使用Google Cloud Compute Engine进行,配置为:一块拥有40GB显存的NVIDIA A100 GPU、12个虚拟CPU和85GB的主机内存。 基准测试结果
基准测试结果

关键发现
发现1
发现2
发现3
发现4

结论
框架的性能在很大程度上取决于具体使用的模型。 Keras 3能够帮助为任务选择最快的框架,这种选择几乎总能超越Keras 2和PyTorch实现。 更为重要的是,Keras 3模型无需进行复杂的底层优化,即可提供卓越的开箱即用性能。 参考资料: https://keras.io/getting_started/benchmarks/






内容中包含的图片若涉及版权问题,请及时与我们联系删除
评论
沙发等你来抢