Math 594: Numerical Computation with JAX
Turning theoretical math into high-performance code. This course provides a deep dive into JAX, a high-performance numerical computing library for Python. We explore how to implement scalable algorithms for large-scale data in fields such as probabilistic machine learning, scientific ML, and AI.
Course Information
- Instructor: Shen-Ning Tung (tung@math.nthu.edu.tw)
- Lecture Time: Wednesdays, 10:10 AM – 12:00 PM and Fridays, 2:20 PM – 3:10 PM
- Office Hours: By appointment
- Target Audience: Students interested in Differentiable Programming, Scientific ML, and Scalable AI.
A Cooperative Learning Environment
This course is designed as a cooperative learning community. Success in Math 594 depends on students actively contributing to the collective mastery of JAX. Beyond standard lectures, students are expected to collaborate, share implementation strategies, and lead discussions on specialized tools. We function as a research-and-development cohort where peer feedback and collaborative troubleshooting are central to the experience.
Course Description
This course emphasizes practical applications in advanced computation. We move beyond standard Python to master JAX's core paradigms:
- Functional Programming: Mastering pure functions and immutable state.
- JIT Compilation: Leveraging XLA to accelerate Python code.
- Automatic Differentiation: Utilizing `grad` for complex gradient-based optimization.
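As a small taste of how these three ideas fit together, here is a minimal sketch using only standard JAX calls (`jax.grad`, `jax.jit`, and the `.at[...].set(...)` functional update). The linear least-squares loss, the toy data, and the step size are illustrative assumptions, not course material.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Mean squared error of a linear model: a pure function of its inputs."""
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# Immutable state: JAX arrays cannot be modified in place.
# `.at[...].set(...)` returns a new array and leaves the original untouched.
w0 = jnp.zeros(3)
w_updated = w0.at[0].set(1.0)

# Automatic differentiation: gradient of `loss` with respect to its first argument `w`.
grad_loss = jax.grad(loss)

# JIT compilation: trace the function once, compile it with XLA, reuse the compiled code.
grad_step = jax.jit(grad_loss)

# Toy data (illustrative only).
x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
y = jnp.array([1.0, 2.0])

g0 = grad_step(w0, x, y)  # gradient of the loss at w = 0

# Plain gradient descent written as a pure update: w -> w - lr * grad(w).
lr = 0.01
w = w0
for _ in range(50):
    w = w - lr * grad_step(w, x, y)

print("initial loss:", loss(w0, x, y), "final loss:", loss(w, x, y))
```

Note that the training loop never mutates `w`; each step rebinds the name to a freshly computed array, which is the pattern JAX's functional model expects.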
Learning Objectives
By the end of this course, students will be able to:
- Implement scalable algorithms using JAX’s functional programming model.
- Optimize performance via Just-In-Time (JIT) compilation and vectorization.
- Develop complex models using automatic differentiation and custom primitives.
- Contribute to the academic community by explaining and documenting technical concepts for peers.
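To make "vectorization" concrete, the sketch below writes a function for a single example and lets `jax.vmap` lift it to a batch. Composing `vmap` with `grad` then yields per-example gradients in one compiled call. The tiny linear model and data are hypothetical and chosen only so the results are easy to check by hand.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    """Prediction for a single example: x has shape (d,)."""
    return jnp.dot(w, x)


def example_loss(w, x, y):
    """Squared error for a single example."""
    return (predict(w, x) - y) ** 2


# vmap maps over axis 0 of `x` while sharing `w` across the batch (in_axes=None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

# vmap(grad(...)) gives one gradient per example; jit compiles the whole thing.
per_example_grads = jax.jit(jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [2.0, 3.0]])
ys = jnp.array([0.0, 0.0, 1.0])

preds = batched_predict(w, xs)          # shape (3,)
grads = per_example_grads(w, xs, ys)    # shape (3, 2): one gradient row per example

# Reference: the same gradients computed with an explicit Python loop.
loop_grads = jnp.stack([jax.grad(example_loss)(w, xi, yi) for xi, yi in zip(xs, ys)])
print(grads)
```

The design point is that `predict` never mentions a batch dimension; `vmap` adds it, so the same single-example code can be reused unchanged inside larger transformations.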
Prerequisites
- Mathematics: Familiarity with Multivariate Calculus and Linear Algebra.
- Programming: Proficiency in Python (NumPy experience is highly recommended).
Evaluation & Projects
- Coding Assignments (25%): Problem sets focusing on implementing and optimizing numerical algorithms.
- In-Class Presentations (25%): Students take the lead in introducing essential concepts and specialized tools relevant to project development, fostering peer-to-peer learning.
- Final Project (50%): A substantial project where students implement, document, and present a sophisticated algorithm or research application. Collaborative projects are encouraged.
Logistics & AI Policy
- Communication: All announcements and technical discussions occur on Discord. Students are encouraged to help troubleshoot peers’ queries in public channels.
- Submissions: All work must be submitted in Markdown or Jupyter Notebook format via GitHub to encourage version control best practices.
- AI Usage: Unrestricted use of AI tools is permitted. No disclosure is required; however, students remain fully responsible for the accuracy, logic, and originality of their submitted work.
Course Schedule
| Weeks | Module | Key Concepts |
|---|---|---|
| 1 | Overview | Why JAX? The Case for Differentiable Programming |
| 2–4 | JAX 101 | Functional programming, Arrays, and JIT compilation |
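One early lesson from pairing functional programming with JIT compilation is that `jax.jit` traces a Python function once per input shape and dtype and then reuses the compiled result. The sketch below uses a deliberately impure global counter, an illustrative anti-pattern rather than recommended practice, to reveal when tracing actually happens.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def f(x):
    global trace_count
    # Impure side effect: this Python line runs only while JAX traces `f`,
    # not on every call to the compiled function.
    trace_count += 1
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


f_jit = jax.jit(f)

out_a = f_jit(jnp.ones(3))    # first call with shape (3,): traces and compiles
out_b = f_jit(jnp.zeros(3))   # same shape and dtype: reuses the cached compilation
out_c = f_jit(jnp.ones(4))    # new shape (4,): traces and compiles again

print("traces:", trace_count)
```

Because side effects are silently skipped on cached calls, code meant for `jit` should be written as pure functions whose outputs depend only on their inputs.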
Resources
- The Elements of Differentiable Programming
- Official JAX Documentation
- Instructor-provided Course Notes