Math 594: Numerical Computation with JAX

Turning theoretical math into high-performance code. This course provides a deep dive into JAX, a Python library for high-performance numerical computing that pairs a NumPy-like array API with automatic differentiation and just-in-time compilation through XLA. We explore how to implement scalable algorithms for large-scale data in fields such as probabilistic machine learning, scientific ML, and AI.

Course Information


A Cooperative Learning Environment

This course is designed as a cooperative learning community. Success in Math 594 depends on students actively contributing to the collective mastery of JAX. Beyond standard lectures, students are expected to collaborate, share implementation strategies, and lead discussions on specialized tools. We function as a research-and-development cohort where peer feedback and collaborative troubleshooting are central to the experience.

Course Description

This course emphasizes practical application to advanced computation, moving beyond standard Python to master the functional programming paradigms at the core of JAX.
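As a small illustration of what this functional style means in practice, the following sketch shows three habits JAX expects: functions that are pure (no hidden state), arrays that are immutable (updated through `.at[...]`, which returns a new array), and randomness driven by an explicit PRNG key. The function name `normalize` and the specific values are illustrative assumptions, not course material.

```python
import jax
import jax.numpy as jnp

# A pure function: the output depends only on the input, and nothing is mutated.
def normalize(x):
    return (x - x.mean()) / x.std()

x = jnp.arange(5.0)  # [0., 1., 2., 3., 4.]

# JAX arrays are immutable, so in-place item assignment raises a TypeError.
try:
    x[0] = 10.0
except TypeError:
    pass  # expected: use a functional update instead

# Functional update: returns a new array and leaves x unchanged.
y = x.at[0].set(10.0)

# Randomness is explicit: a key is created and split rather than relying on global state.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(5,))

print(x, y, normalize(x))
```

Passing keys explicitly is what makes random code reproducible and safe to transform with `jit` or `vmap`, since the same key always yields the same samples.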

Learning Objectives

By the end of this course, students will be able to:

  1. Implement scalable algorithms using JAX’s functional programming model.
  2. Optimize performance via Just-In-Time (JIT) compilation and vectorization.
  3. Develop complex models using automatic differentiation and custom primitives.
  4. Contribute to the academic community by explaining and documenting technical concepts for peers.
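Objectives 2 and 3 map onto a handful of JAX transformations. The sketch below is a minimal example under assumed toy data: it JIT-compiles a squared-error loss for a linear model, differentiates it with `jax.grad`, and vectorizes the gradient over a batch with `jax.vmap`. For the "custom" part of objective 3 it uses `jax.custom_jvp` to attach a hand-written derivative rule to a function; this is one common way to extend JAX's differentiation, while defining an entirely new primitive is a lower-level mechanism not shown here. The names `loss`, `batched_grads`, and `log1pexp` are hypothetical.

```python
import jax
import jax.numpy as jnp

# JIT compilation: trace once, compile with XLA, reuse the compiled version.
@jax.jit
def loss(w, x, y):
    # Squared error of a linear model on a single example.
    return (jnp.dot(w, x) - y) ** 2

# Automatic differentiation with respect to the first argument (w).
grad_loss = jax.grad(loss)

# Vectorization: share w (None), map over rows of xs and entries of ys (0).
batched_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 0.0, 0.0])
g = batched_grads(w, xs, ys)  # shape (3, 2): one gradient per example

# A custom differentiation rule: log(1 + e^x) with a numerically stable derivative.
@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x); stays finite even when e^x overflows.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

print(g)
print(jax.grad(log1pexp)(100.0))
```

Without the custom rule, differentiating `jnp.log1p(jnp.exp(x))` at large `x` produces `inf / inf` in single precision and returns `nan`; the hand-written rule avoids that.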

Prerequisites


Evaluation & Projects


Logistics & AI Policy


Course Schedule

Weeks | Module   | Key Concepts
1     | Overview | Why JAX? The Case for Differentiable Programming
2–4   | JAX 101  | Functional programming, arrays, and JIT compilation
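The schedule's opening themes, differentiable programming and JIT compilation, meet in examples like the following minimal sketch: an iterative numerical routine written as an ordinary function, then compiled and differentiated end to end. The choice of Newton's method for square roots, the fixed count of 20 iterations, and the name `newton_sqrt` are illustrative assumptions, not course material.

```python
import jax
import jax.numpy as jnp
from jax import lax

def newton_sqrt(a):
    # Solve x^2 = a with Newton's method: x <- (x + a / x) / 2.
    a = jnp.asarray(a, dtype=jnp.float32)

    def body(_, x):
        return 0.5 * (x + a / x)

    # lax.fori_loop keeps the loop inside the compiled program. With Python-int
    # bounds (a fixed trip count) it is also reverse-mode differentiable.
    return lax.fori_loop(0, 20, body, jnp.ones_like(a))

sqrt_jit = jax.jit(newton_sqrt)
dsqrt = jax.jit(jax.grad(newton_sqrt))  # differentiate through every iteration

print(sqrt_jit(4.0))  # close to 2.0
print(dsqrt(4.0))     # close to 1 / (2 * sqrt(4)) = 0.25
```

Using `lax.fori_loop` rather than a Python `for` loop keeps the compiled program small; a Python loop would also work here, but it is unrolled during tracing, so compile time grows with the iteration count.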

Resources