Comment by robrenaud

6 days ago

I worked on a similar problem about a year ago, on large dense models.

https://www.lesswrong.com/posts/PkeB4TLxgaNnSmddg/scaling-sp...

In both cases, the goal is to actually learn a concrete circuit inside a network that solves specific Python next-token prediction tasks. We each end up with a crisp wiring diagram saying “these are the channels/neurons/heads that implement this particular bit of Python reasoning.”

Both projects cast circuit discovery as a gradient-based selection problem over a fixed base model. We train a mask that picks out a sparse subset of computational nodes as “the circuit,” while the rest are ablated. Their work learns masks over a weight-sparse transformer; ours learns masks over SAE latents and residual channels. But in both cases, the key move is the same: use gradients to optimize which nodes are included, rather than relying purely on heuristic search or attribution patching. Both approaches also use a gradual hardening schedule (continuous masks that are annealed or sharpened over time) so that we can keep gradients useful early on, then spend extra compute to push the mask towards a discrete, minimal circuit that still reproduces the model’s behavior.
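To make that setup concrete, here is a minimal sketch of the shared recipe: a learned continuous mask over nodes, trained with a faithfulness-plus-sparsity loss and a temperature-annealing ("hardening") schedule. The toy two-layer network, the hard-concrete-style relaxation, and all hyperparameters are illustrative stand-ins, not either project's actual code; in the real versions the nodes are SAE latents, residual channels, or weight groups inside a transformer.

```python
import torch
import torch.nn.functional as F

# Toy stand-in for the base model: "nodes" are hidden units we can zero out.
torch.manual_seed(0)
num_nodes, vocab = 256, 50
W_in = torch.randn(32, num_nodes)
W_out = torch.randn(num_nodes, vocab)

def run_with_mask(x, mask):
    hidden = torch.relu(x @ W_in) * mask   # nodes outside the mask are ablated
    return hidden @ W_out

x = torch.randn(64, 32)
target_logits = run_with_mask(x, torch.ones(num_nodes)).detach()  # unablated behavior

mask_logits = torch.zeros(num_nodes, requires_grad=True)
opt = torch.optim.Adam([mask_logits], lr=5e-2)
num_steps, sparsity_coeff = 500, 0.1

for step in range(num_steps):
    # Gradual hardening: high temperature early keeps gradients informative,
    # low temperature later pushes the mask towards a discrete 0/1 circuit.
    temperature = max(0.05, 1.0 - step / num_steps)
    u = torch.rand(num_nodes).clamp(1e-6, 1 - 1e-6)
    noise = torch.log(u) - torch.log1p(-u)
    mask = torch.sigmoid((mask_logits + noise) / temperature)

    logits = run_with_mask(x, mask)
    # Faithfulness: the masked circuit should reproduce the full model's outputs.
    faithfulness = F.kl_div(logits.log_softmax(-1),
                            target_logits.softmax(-1), reduction="batchmean")
    # Sparsity: prefer a small circuit.
    loss = faithfulness + sparsity_coeff * mask.mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

circuit = torch.sigmoid(mask_logits) > 0.5  # final discrete node selection
print(f"kept {circuit.sum().item()} / {num_nodes} nodes")
```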

The similarities extend to how we validate and stress-test the resulting circuits. In both projects, we drill down enough to notice "bugs" or quirks in the learned mechanism and to deliberately break it: by making simple, semantically small edits to the Python source, we can systematically make the pruned circuit fail, and those failures generalize to the unpruned network. That gives us some confidence that we're genuinely capturing the specific mechanism the model is using.
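A rough sketch of that stress test is below. The `predict_circuit` / `predict_full` callables and the example prompt pair are hypothetical placeholders (run the pruned circuit or the unpruned model on a prompt and take the argmax next token), not either project's actual harness.

```python
# Illustrative stress test: for each prompt, apply a semantically small edit
# and check whether a failure induced in the pruned circuit also shows up in
# the unpruned model. `predict_circuit` and `predict_full` are hypothetical.

def stress_test(predict_circuit, predict_full, cases):
    results = []
    for prompt, expected, edited, edited_expected in cases:
        # Sanity check: both the circuit and the full model handle the original prompt.
        assert predict_circuit(prompt) == expected
        assert predict_full(prompt) == expected
        # The edit is semantically small but targets the circuit's mechanism;
        # record whether the induced failure generalizes to the full model.
        circuit_fails = predict_circuit(edited) != edited_expected
        full_fails = predict_full(edited) != edited_expected
        results.append((edited, circuit_fails, full_fails))
    return results

# Hypothetical case: renaming a variable is a semantically small edit that can
# break a circuit relying on a surface cue.
cases = [
    ("xs = [1, 2, 3]\nprint(len(", "xs",
     "ys = [1, 2, 3]\nprint(len(", "ys"),
]
```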