🧠

Feature Visualization

Neuromatch Academy Deep Learning Track


📖 Background

This project was created during my time at Neuromatch Academy's Deep Learning track — an intensive computational neuroscience and deep learning summer school.

I was fascinated by the question: What's actually happening inside these networks? This led me down the path of mechanistic interpretability — understanding the internal mechanisms of neural networks by visualizing their learned representations.

Looking back, I was exploring mech interp before it became the hot topic it is today! 🔥

🔬 The Question

What do the intermediate representations look like, and how do they evolve through the network?

1. Learned Convolutional Filters

The first layer learns to detect basic features like edges, gradients, and textures. Deeper layers combine these into more complex, abstract features.

Learned Filters

Observations

  • Layer 1 filters detect simple patterns (edges, gradients)
  • Deeper layers combine these into complex features
  • Filters show diverse orientations and frequencies
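A minimal sketch of how first-layer kernels can be pulled out of a Flax conv layer and tiled for plotting. The layer and parameter names are Flax defaults rather than this project's code, and the kernels here are untrained; in practice you would load the trained parameters first.

```python
# Sketch: extract the first-layer kernels from a Flax Conv and plot them.
import jax, jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt

conv1 = nn.Conv(features=32, kernel_size=(3, 3))
params = conv1.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))  # NHWC dummy input
kernel = params["params"]["kernel"]                # shape (3, 3, 1, 32): HWIO layout

fig, axes = plt.subplots(4, 8, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(kernel[:, :, 0, i], cmap="gray")     # one 3x3 filter per panel
    ax.axis("off")
plt.suptitle("First-layer filters (load trained params for real visualizations)")
plt.show()
```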

2. Feature Maps (Activations)

Each filter produces an activation map showing where it "fires" in the input. Watch how spatial resolution decreases but feature richness increases through layers.

Activations

Observations

  • Early layers preserve spatial structure
  • Deeper layers: smaller spatial resolution, richer features
  • Different channels respond to different input aspects
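One way to inspect these maps, sketched below with illustrative module and channel counts, is to have the model return its intermediate activations alongside its output:

```python
# Sketch: a two-block CNN front end that returns its own feature maps,
# so activations can be inspected without any capture hooks.
import jax, jax.numpy as jnp
import flax.linen as nn

class FeatureMapCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        a1 = nn.relu(nn.Conv(32, (3, 3))(x))            # 28x28x32
        p1 = nn.max_pool(a1, (2, 2), strides=(2, 2))    # 14x14x32
        a2 = nn.relu(nn.Conv(64, (3, 3))(p1))           # 14x14x64
        p2 = nn.max_pool(a2, (2, 2), strides=(2, 2))    # 7x7x64
        return p2, {"conv1": a1, "conv2": a2}

model = FeatureMapCNN()
x = jnp.ones((1, 28, 28, 1))                            # dummy MNIST-shaped batch
params = model.init(jax.random.PRNGKey(0), x)
_, feature_maps = model.apply(params, x)
print({k: v.shape for k, v in feature_maps.items()})
# conv1: (1, 28, 28, 32) keeps full spatial resolution;
# conv2: (1, 14, 14, 64) is spatially smaller but has more channels.
```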

3. Individual Channel Activations

Each channel in a layer responds to different features. Some act as edge detectors, others respond to textures or patterns.

Channel Activations
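A tiny sketch of isolating one channel for plotting; the `acts` tensor here is a random placeholder standing in for a real feature map such as `feature_maps["conv1"]` above, and the channel index is arbitrary.

```python
# Sketch: view a single channel of an activation map in isolation.
import jax, jax.numpy as jnp
import matplotlib.pyplot as plt

acts = jax.random.normal(jax.random.PRNGKey(0), (1, 28, 28, 32))  # placeholder activations
channel = 5                                                        # arbitrary channel index
plt.imshow(acts[0, :, :, channel], cmap="viridis")
plt.title(f"Channel {channel} activation map")
plt.colorbar()
plt.show()
```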

📐 Technical Details

Architecture

SimpleCNN:
├── Conv(32, 3×3) + ReLU + Pool
├── Conv(64, 3×3) + ReLU + Pool
├── Conv(64, 3×3) + ReLU
├── Flatten → Dense(128)
└── Output(10)
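A sketch of this architecture in Flax; details the diagram leaves open (pooling windows, padding, the activation after Dense(128)) are my assumptions.

```python
# Sketch of SimpleCNN in Flax, following the diagram above.
import jax.numpy as jnp
import flax.linen as nn

class SimpleCNN(nn.Module):
    @nn.compact
    def __call__(self, x):                              # x: (N, 28, 28, 1)
        x = nn.relu(nn.Conv(32, (3, 3))(x))
        x = nn.max_pool(x, (2, 2), strides=(2, 2))      # -> 14x14x32
        x = nn.relu(nn.Conv(64, (3, 3))(x))
        x = nn.max_pool(x, (2, 2), strides=(2, 2))      # -> 7x7x64
        x = nn.relu(nn.Conv(64, (3, 3))(x))             # -> 7x7x64
        x = x.reshape((x.shape[0], -1))                 # flatten
        x = nn.relu(nn.Dense(128)(x))                   # ReLU here is an assumption
        return nn.Dense(10)(x)                          # class logits
```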

Stack

  • Framework: JAX/Flax
  • Hardware: TPU v4
  • Dataset: MNIST
  • Context: Neuromatch Academy
🔄

Part 2: RNN/LSTM Visualization

What happens inside recurrent networks?

4. Hidden State Evolution

As the LSTM processes each element in the sequence, its hidden state evolves to encode relevant information.

Hidden Evolution
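A minimal sketch of this idea: a hand-written LSTM step scanned over a binary sequence with `jax.lax.scan`, recording the hidden (and cell) state at every timestep. The weights are random and the hidden size is illustrative, not the project's actual settings.

```python
# Sketch: track LSTM hidden-state evolution across a sequence.
import jax, jax.numpy as jnp

hidden = 16                                              # illustrative hidden size
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
W = 0.1 * jax.random.normal(k_w, (4, hidden, hidden + 1))   # stacked gate weights: i, f, o, g
b = jnp.zeros((4, hidden))

def lstm_step(carry, x_t):
    h, c = carry
    z = jnp.concatenate([h, x_t])                        # [h_{t-1}, x_t]
    i = jax.nn.sigmoid(W[0] @ z + b[0])                  # input gate: what new info to store
    f = jax.nn.sigmoid(W[1] @ z + b[1])                  # forget gate: what to discard
    o = jax.nn.sigmoid(W[2] @ z + b[2])                  # output gate: what to expose
    g = jnp.tanh(W[3] @ z + b[3])                        # candidate memory content
    c = f * c + i * g                                    # cell state (long-term memory)
    h = o * jnp.tanh(c)                                  # hidden state
    return (h, c), (h, c)                                # carry forward, and record per step

seq = jax.random.bernoulli(k_x, 0.5, (20, 1)).astype(jnp.float32)  # random binary sequence
init = (jnp.zeros(hidden), jnp.zeros(hidden))
_, (hs, cs) = jax.lax.scan(lstm_step, init, seq)
print(hs.shape, cs.shape)                                # (20, 16) each: one state per timestep
```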

5. LSTM Gate Activations

The three gates control how information flows into, through, and out of the cell's memory:

Forget Gate

What to discard from memory

Input Gate

What new info to store

Output Gate

What to output

Gates
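For reference, the standard LSTM update equations these gates implement (one common parameterization; a given library's variant may differ in details). The last two lines also describe the cell-state dynamics covered in the next section.

$$
\begin{aligned}
f_t &= \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) &&\text{(forget gate)}\\
i_t &= \sigma\left(W_i\,[h_{t-1}, x_t] + b_i\right) &&\text{(input gate)}\\
o_t &= \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) &&\text{(output gate)}\\
\tilde{c}_t &= \tanh\left(W_c\,[h_{t-1}, x_t] + b_c\right) &&\text{(candidate memory)}\\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t &&\text{(cell state update)}\\
h_t &= o_t \odot \tanh(c_t) &&\text{(hidden state)}
\end{aligned}
$$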

6. Cell State (Memory) Dynamics

The cell state carries information across the sequence — the LSTM's "memory".

Cell State

7. RNN Summary

RNN Summary

🎯 Task & Results

Task: Classify whether a binary sequence contains more 1s than 0s
Test Accuracy: 100%. The LSTM learns to count!
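A sketch of how such a dataset could be generated; this is my reconstruction of the task as described above, not the project's exact data pipeline.

```python
# Sketch: synthetic data for the counting task.
import jax, jax.numpy as jnp

def make_batch(key, batch_size=64, seq_len=20):
    seqs = jax.random.bernoulli(key, 0.5, (batch_size, seq_len)).astype(jnp.float32)
    # Label 1 when there are strictly more 1s than 0s; ties get label 0 here,
    # since the original task statement does not say how they were handled.
    labels = (seqs.sum(axis=1) > seq_len / 2).astype(jnp.int32)
    return seqs[..., None], labels                       # add a feature dim for the LSTM

x, y = make_batch(jax.random.PRNGKey(0))
print(x.shape, y.shape)                                  # (64, 20, 1) (64,)
```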

💡 Key Findings

CNNs:

  1. Hierarchical representations (edges → patterns)
  2. Different filters specialize in different features
  3. Spatial relationships preserved

RNNs/LSTMs:

  1. Gates control what to remember/forget
  2. Cell state maintains long-term memory
  3. Hidden state encodes sequence context