The first step...
I have been working on AI projects since 2017, primarily as the token software engineer in a group of AI researchers. My roles have consisted of taking models trained by research teams and integrating them into applications that real people can use. These experiences have given me a degree of exposure to how AI is built, developed, and run that most software engineers haven't had. That said, without a formal background in AI or much experience on the model development side, there has been a gap in my understanding of how the tools we use actually work. The time has come for me to close that gap.
Given that I'm still a working engineer and don't have time to stop and go back to grad school, I decided to take advantage of the huge improvements in LLMs over the last couple of years to augment my education. I used Claude to help me construct a graduate-level course on AI inference covering the latest optimization techniques being used in industry right now. I am currently working through the first phase of that course and will be documenting my learning journey through a series of blog posts, of which this is the first. If this sounds interesting to you, check back regularly for updates as I work through the material with Claude. Okay, enough background; let's get to why we're here.
Setup
As part of the first phase of my self-directed course I have gone back to the basics of how neural networks are constructed, starting with the simplest form of neural network: one built from scalar (single) values only. Specifically, I have been working through Andrej Karpathy's Neural Networks: Zero to Hero series.
In this series Karpathy introduces a simple autograd engine that works on scalar values called micrograd. This library is written in Python and demonstrates at the most simple level how libraries like PyTorch and Jax work conceptually when building networks and automatically computing gradients (hence the term autograd) of those networks.
Being the avid Rust programmer that I am, I decided as an exercise for myself to implement Karpathy's micrograd library in Rust. I have named this library crabwalk.
Values and Ownership
At the simplest level, micrograd constructs a computation graph from the operations that produce each node and uses that graph to automatically compute the gradient for each node using the chain rule from calculus. You may have heard of this previously under the term backpropagation.
For the first part of porting micrograd to Rust I focused on laying the foundations of the Value struct and implementing the basic Add, Mul, and backward operations. This gives us a functioning but simple autograd engine on which we can build to add more complex operations.
Due to the ownership constraints of Rust, special care had to be taken with the design of the library internals, specifically the internals of the Value struct. The primary challenge here comes from the fact that Value nodes can be the children of multiple other Value nodes. A Value node needs to be "owned" by multiple entities while retaining the ability to be mutated via the backwards pass.
let a = Value::new(10.0);
let b = Value::new(20.0);
let c: Value = &a * &b; // a is a child of c
let d: Value = &c + &a; // a is also a child of d
In the above example we have a Value node a that, through addition and multiplication, is a child of both the c and d Value nodes. Child in this context means that a contributed to the output of both c and d. This is important to track because, to compute the gradient for each element in the graph, we need to walk from the root (output) to the leaves (inputs), computing the gradient using the chain rule. Importantly, to compute the correct gradient for node a in our example we need the gradients from both the d and c nodes.
In the Python implementation this just works: since all objects are heap allocated by default, we can include them as children of multiple Value nodes without issue. In the Rust port this isn't the case because of Rust's ownership model, so this property of Value nodes needs to be designed into the library explicitly.
In crabwalk this was handled by splitting the Value node into two structs. A public Value struct that is the interface that users of the library will interact with and a private ValueInner struct that contains the data for each node. The Value struct wraps a ValueInner struct in an Rc<RefCell<ValueInner>>. Wrapping the ValueInner in an Rc<RefCell<_>> gives us the properties that we need to have multiple ownership and interior mutability.
#[derive(Clone)]
pub struct Value(Rc<RefCell<ValueInner>>);
Rc is a reference-counted wrapper that does not drop the inner type until the reference count reaches 0. This gives us the ability to cheaply clone the Value struct without duplicating the data: each clone simply increments the Rc reference counter. If you are familiar with Rust you may be asking yourself why I chose Rc instead of Arc. Arc provides the same properties as Rc with the notable addition that it is thread-safe because it uses an atomic counter (the A in Arc). Rc was chosen here because we don't need the thread-safety properties of Arc and therefore don't need to pay the overhead of atomics. Constructing the computation graph, the topological sort, and the backwards pass are all inherently serial operations that don't benefit from parallelism, making Rc the better choice.
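The cheap-clone claim is easy to verify with the standard library alone. This little sketch (the function name is mine, not crabwalk's) watches the counter move as a clone is created and dropped:

```rust
use std::rc::Rc;

/// Returns the strong count before and after cloning an Rc.
fn count_after_clone() -> (usize, usize) {
    let value = Rc::new(10.0_f64);
    let before = Rc::strong_count(&value);

    // The clone copies only the pointer and bumps the counter;
    // the f64 itself is not duplicated.
    let alias = Rc::clone(&value);
    let after = Rc::strong_count(&value);

    // Both handles point at the exact same allocation.
    assert!(Rc::ptr_eq(&value, &alias));

    drop(alias); // the count falls back once a clone is dropped
    assert_eq!(Rc::strong_count(&value), 1);

    (before, after)
}

fn main() {
    println!("{:?}", count_after_clone()); // prints (1, 2)
}
```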
The RefCell wrapper gives us the interior mutability that we need to update gradients of each Value in the backwards pass. Since a single Value node can be a child of multiple other Value nodes, as we saw in the earlier example, we need multiple "owners" to be able to mutate the inner data when updating their gradient so that all the gradients that flow into a specific Value node are accumulated properly.
let a = Value::new(10.0); // Correct gradient for a needs to accumulate from both operations that create c and d
let b = Value::new(20.0);
let c: Value = &a * &b; // compute gradient of a here
let d: Value = &c + &a; // and here as well
Operations
Now that we have the lay of the land for the ownership pattern, let's talk about the addition and multiplication operations on Value. To properly compute the gradient of every value in the computation graph we need to know which operation created each value, as the operation used affects how we compute the gradient. Sometimes Value nodes are created directly (a and b in our example) and other times new Value nodes are created as the result of an arithmetic operation (c and d in our example). Let's take a look at the implementation we need for addition.
impl Add for Value {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        let new_val = self.0.borrow().data + other.0.borrow().data;
        let out = Value(Rc::new(RefCell::new(ValueInner::new(
            new_val,
            vec![self.clone(), other.clone()],
        ))));
        {
            let mut inner = out.0.borrow_mut();
            // Need to use a Weak reference here to prevent a memory leak due to the
            // capture of `out`.
            let out_weak = Rc::downgrade(&out.0);
            inner.backward = Some(Box::new(move || {
                let out_inner = out_weak.upgrade().unwrap();
                let grad = out_inner.borrow().grad;
                self.0.borrow_mut().grad += grad;
                other.0.borrow_mut().grad += grad;
            }));
            inner.op = Some("+");
        }
        out
    }
}
Implementing the std::ops::Add trait in Rust allows us to use the + operator to add two values of a custom type together. We can use this custom implementation to do the bookkeeping that we'll need for the backwards pass to compute the gradients. As you can see in the above example, we add the underlying data (f64) and create a new Value node with the new data as well as self and other set as children. We then need to set the definition of the backwards pass for this operation on the newly created Value node. For addition, the backwards pass simply accumulates the gradient of the output node into each input: increasing either input by a small amount increases the output by exactly the same amount, so the local gradient is 1 for both inputs.
IMPORTANT NOTE: Because a single Value node may be used multiple times we need to accumulate += the gradient and not simply set the gradient using =. Failure to do this would result in an incorrect gradient for the node.
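To see why this matters, consider the two gradient contributions that flow into a in our running example d = (a * b) + a: one from the multiplication's backward pass and one from the addition's. This sketch (a hypothetical helper, not crabwalk code) compares accumulating against overwriting:

```rust
/// Contributions into a reused node: 3.0 from the `*` backward pass
/// (the value of b) and 1.0 from the `+` backward pass.
fn grads() -> (f64, f64) {
    let contributions = [3.0_f64, 1.0];

    // Accumulating with += collects every path through the graph.
    let mut accumulated = 0.0;
    for g in contributions {
        accumulated += g;
    }

    // Overwriting with = silently keeps only the last contribution.
    let mut overwritten = 0.0;
    for g in contributions {
        overwritten = g;
    }

    (accumulated, overwritten)
}

fn main() {
    let (ok, bad) = grads();
    println!("correct: {ok}, buggy: {bad}"); // correct: 4, buggy: 1
}
```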
The same conceptual approach is used for the std::ops::Mul implementation that allows us to use the * operator, with the only difference being how the backward function is defined. For multiplication, the gradient computation is the accumulation of the other term's data multiplied by the output node's gradient. In English: if you have the expression c = a * b, when you change the value of a by a small amount, how much c changes depends on how large b is, and vice versa. See below for the concrete implementation in crabwalk.
//...
{
    let mut inner = out.0.borrow_mut();
    // Need to use a Weak reference here to prevent a memory leak due to the
    // capture of `out`.
    let out_weak = Rc::downgrade(&out.0);
    inner.backward = Some(Box::new(move || {
        let out_inner = out_weak.upgrade().unwrap();
        let grad = out_inner.borrow().grad;
        let self_data = self.0.borrow().data;
        let other_data = other.0.borrow().data;
        self.0.borrow_mut().grad += other_data * grad;
        other.0.borrow_mut().grad += self_data * grad;
    }));
    inner.op = Some("*");
}
//...
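A quick way to sanity-check the product rule used above is to compare the analytic gradients against a numerical estimate. This standalone check (the helper functions are illustrative, not part of crabwalk) perturbs a by a small h and measures how much a * b moves:

```rust
/// Analytic gradients of c = a * b: dc/da = b and dc/db = a.
fn product_grads(a: f64, b: f64) -> (f64, f64) {
    (b, a)
}

/// Numerical gradient of a * b with respect to `a` via central differences.
fn numeric_grad_a(a: f64, b: f64) -> f64 {
    let h = 1e-6;
    ((a + h) * b - (a - h) * b) / (2.0 * h)
}

fn main() {
    let (a, b) = (10.0, 20.0);
    let (da, db) = product_grads(a, b);

    // The analytic and numerical estimates should agree closely.
    assert!((da - numeric_grad_a(a, b)).abs() < 1e-3);

    println!("dc/da = {da}, dc/db = {db}"); // dc/da = 20, dc/db = 10
}
```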
Running in Circles (the reference problem)
If you've been paying attention you'll have noticed the use of Rc::downgrade and a Weak reference for the backward function implementation for both the Add and Mul traits. I want to take a moment now and explain why. The reason that we need to downgrade and use a Weak reference here is so that we don't end up in a cycle of strong references that prevent a Value node from ever being cleaned up.
Looking closely at the implementation, you'll see that we borrow the ValueInner struct from the newly created Value node as mutable and then create a new closure that will be set as the backward function on the new Value node. To properly compute the gradients for the children (self and other) we need to utilize the gradient of the Value node we just created. To make sure we are using the most up-to-date value for grad on the output node when backward is called we need to capture the node itself. If we just captured the value of grad when we create this function it would always be set to the default value of 0.0 resulting in an incorrect gradient computation.
This puts us in a tight spot though. With Rc we could easily call clone and create another pointer to the Value that can be moved inside the closure, but this closure is being set on the very Value node we just cloned. This would create a cyclic reference in this Value node where the counter in Rc would never reach 0 (because of the captured Value in the closure) and therefore the data would never get cleaned up. This is perfectly memory safe and the Rust compiler will let you do it, but it would lead to a nasty memory leak, which we don't want.
The solution here is to use a Weak reference instead of a strong reference to the new Value node we create in the operation. A Weak reference does not increment the reference counter, so when it is captured by the closure we won't have any issues with cyclic references. When backward is called, we upgrade the Weak reference to a strong reference so that the underlying value won't be dropped while we're using it, keeping this operation memory safe. Crucially, this upgraded reference only lasts for the scope of the closure and is dropped (the counter is decremented) at the end of the closure. A detailed annotation of the code is below.
{
    let mut inner = out.0.borrow_mut();
    let out_weak = Rc::downgrade(&out.0); // Create a `Weak` reference
    inner.backward = Some(Box::new(move || {
        // The `Weak` reference is captured by the closure
        let out_inner = out_weak.upgrade().unwrap(); // The `Weak` reference is upgraded to a strong reference
        let grad = out_inner.borrow().grad;
        self.0.borrow_mut().grad += grad;
        other.0.borrow_mut().grad += grad;
    })); // The strong reference is dropped here
    inner.op = Some("+");
}
Backwards
Now we get to the auto part of autograd. The backward method on the Value node is how the gradient is automatically calculated after the computation graph is constructed. This is done by iterating over all of the nodes from the output to the input, computing the gradient along the way. You can think of this as a tree that starts at the output node as the root and flows towards the inputs at the leaves.
impl Value {
    // ...
    pub fn backward(&self) {
        let mut visited = HashSet::new();
        let mut topo = self.build_topo(&mut visited);
        topo.reverse();
        self.0.borrow_mut().grad = 1.0;
        for node in topo {
            if let Some(backward_fn) = &node.0.borrow().backward {
                backward_fn();
            }
        }
    }

    /// Topological sort of the computation graph.
    ///
    /// Recursively calls `build_topo` on children until a node
    /// is reached that has no children. Each call returns a `Vec<Value>`
    /// that is merged with the parent's `Vec<Value>` before being returned
    /// up to the next level, eventually ending with the root node.
    fn build_topo(&self, visited: &mut HashSet<*mut ValueInner>) -> Vec<Value> {
        let mut topo = vec![];
        if !visited.contains(&self.0.as_ptr()) {
            visited.insert(self.0.as_ptr());
            for child in &self.0.borrow().children {
                let mut child_topo = child.build_topo(visited);
                topo.append(&mut child_topo);
            }
            topo.push(self.clone());
        }
        topo
    }
}
In our example computation graph, the gradient of d needs to be computed first, then c, then a, and so on. To get the correct ordering of the Value nodes in the graph we perform a topological sort using a post-order DFS. The post-order list is then reversed to start from the root and move towards the leaves.
let a = Value::new(10.0); // a is a leaf
let b = Value::new(20.0); // b is a leaf
let c: Value = &a * &b; // c is the parent of a and b
let d: Value = &c + &a; // d is the output node, a.k.a. the root; it is the parent of c and a
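The post-order walk can be sketched over plain indices without any of the Rc machinery (this index-based graph encoding is illustrative, not how crabwalk stores nodes): node 0 is a, 1 is b, 2 is c, and 3 is d:

```rust
use std::collections::HashSet;

/// Post-order DFS: a node is pushed only after all of its children.
fn build_topo(node: usize, children: &[Vec<usize>], visited: &mut HashSet<usize>, topo: &mut Vec<usize>) {
    if visited.insert(node) {
        for &child in &children[node] {
            build_topo(child, children, visited, topo);
        }
        topo.push(node);
    }
}

/// Builds the post-order list from `root` and reverses it so the walk
/// runs from the root towards the leaves, as the backwards pass needs.
fn reverse_topo_from(root: usize, children: &[Vec<usize>]) -> Vec<usize> {
    let mut topo = Vec::new();
    build_topo(root, children, &mut HashSet::new(), &mut topo);
    topo.reverse();
    topo
}

fn main() {
    // 0 = a, 1 = b, 2 = c = a * b, 3 = d = c + a
    let children = vec![vec![], vec![], vec![0, 1], vec![2, 0]];
    println!("{:?}", reverse_topo_from(3, &children)); // prints [3, 2, 1, 0]
}
```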
During the construction of the sorted list it's important that each node only appears once. Since we don't have any sort of unique identifier for Value nodes, we need some way to know if we've already seen a node and placed it in the list. This is solved with a combination of a HashSet and the pointer address of the Rc inside each Value node. Since the data inside the Rc is heap allocated and there is only ever one location in memory for each node's data, the pointer address can be used to determine whether a node has already been added to the list.
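The pointer-identity trick can be demonstrated directly: two clones of the same Rc hash to the same address, while a distinct node holding equal data does not (this sketch uses Rc<RefCell<f64>> in place of the real ValueInner):

```rust
use std::cell::RefCell;
use std::collections::HashSet;
use std::rc::Rc;

/// Counts distinct nodes when one node is reachable via two edges.
fn unique_node_count() -> usize {
    let a = Rc::new(RefCell::new(10.0_f64));
    let a_again = Rc::clone(&a); // same node, reached via another edge
    let b = Rc::new(RefCell::new(10.0_f64)); // different node with equal data

    // RefCell::as_ptr gives the address of the inner data, which is
    // stable and unique per allocation.
    let mut visited: HashSet<*mut f64> = HashSet::new();
    assert!(visited.insert(a.as_ptr())); // first visit: inserted
    assert!(!visited.insert(a_again.as_ptr())); // same allocation: rejected
    assert!(visited.insert(b.as_ptr())); // equal value, distinct node: inserted

    visited.len()
}

fn main() {
    println!("unique nodes seen: {}", unique_node_count()); // prints 2
}
```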
With the list constructed all there is left to do is to iterate all the entries calling the backward function on each node if it's present. This will in turn call the closure that was constructed in the implementation of the operation to correctly differentiate the child values.
Conclusion
That's it! That's a very basic but complete implementation of an autograd engine for scalar values. There are more operations that I'll be adding to crabwalk to reach complete parity with micrograd but the pattern that we've established here with the Value struct will hold for the rest of the operation implementations. Below is a very simple computation graph that correctly computes the gradient of a to prove that our library works.
use crabwalk::Value;

// Note: The comment ordering here follows the logic from the perspective of the
// backwards pass and not the forward pass. Read the code comments in reverse
// order, starting from the bottom and moving to the top.
fn main() {
    let a = Value::new(2.0); // No gradient computation, this is input
    let b = Value::new(3.0); // No gradient computation, this is input
    let c = &a * &b; // Local gradient of a is 3.0 (value of b * gradient of c).
                     // The total gradient of a is 4.0 because the local
                     // gradient is accumulated with the previous gradient of 1.0
    let d = &c + &a; // Local gradient of a is 1.0. The total gradient of a is
                     // also 1.0 because this is the first operation up from the
                     // root (the default gradient of d is 1.0)
    d.backward(); // Gradient of d starts at 1.0
    println!("a grad: {:#?}", a.grad());
}
$ cargo run --example add_mul
a grad: 4.0
I'll be continuing to expand upon this port of micrograd and diving deeper into the depths of modern AI inference from the ground up. Check out the GitHub link below for the full crabwalk implementation.
Check back for future blog posts! As always, if you liked this, please share it with others who may also find it interesting!
See you in the next one!