LSTM Gates Visualizer
Interactive visualization of Long Short-Term Memory gates and cell state dynamics
Understanding LSTM Networks
Long Short-Term Memory (LSTM) networks are a special type of Recurrent Neural Network (RNN) capable of learning long-term dependencies. Unlike standard RNNs, LSTMs have a cell state that runs through the entire chain, with gates that regulate what information to keep or discard.
Key Components:
- Forget Gate (Orange): Decides what information to discard from the cell state, using a sigmoid activation (0 = discard, 1 = keep)
- Input Gate (Green): Decides what new information to store in the cell state
- Output Gate (Blue): Decides what to output based on the cell state
- Cell State: The memory that runs through time, carrying long-term information
- Hidden State: The output at each timestep, used for predictions and next step
How it works: At each timestep, the LSTM takes an input and the previous hidden state. The forget gate decides what to remove from memory, the input gate adds new information, and the output gate decides how much of the updated cell state becomes the new hidden state. This allows LSTMs to remember important information over many timesteps and forget irrelevant details.
LSTM Cell Architecture & Gate Activations
Gate Activity Heatmaps Over Time
Darker = Higher activation (closer to 1)
Lighter = Lower activation (closer to 0)
Cell State vs Hidden State Evolution
Cell State (Long-term Memory)
/**
 * Plot the per-unit cell-state and hidden-state trajectories recorded so far
 * on the state-evolution canvas.
 *
 * Reads the globals `stateCanvas`, `stateCtx`, `cellStateHistory` and
 * `hiddenStateHistory` (one vector of unit values per processed timestep).
 * Cell-state units are drawn in purple and hidden-state units in cyan. Opacity
 * rises with the unit index so individual lines stay distinguishable. Both
 * series share one vertical scale: the largest magnitude seen, but never less
 * than 1. When nothing has been recorded yet, a placeholder message is shown.
 */
function drawStateEvolution() {
  stateCtx.clearRect(0, 0, stateCanvas.width, stateCanvas.height);

  if (cellStateHistory.length === 0) {
    stateCtx.fillStyle = '#4a5568';
    stateCtx.font = '14px Arial';
    stateCtx.textAlign = 'center';
    stateCtx.fillText('Cell state and hidden state evolution will appear here',
      stateCanvas.width / 2, stateCanvas.height / 2);
    return;
  }

  const padding = 50;
  const w = stateCanvas.width - 2 * padding;
  const h = stateCanvas.height - 2 * padding;

  // Draw axes
  stateCtx.strokeStyle = '#2d3748';
  stateCtx.lineWidth = 2;
  stateCtx.beginPath();
  stateCtx.moveTo(padding, padding);
  stateCtx.lineTo(padding, padding + h);
  stateCtx.lineTo(padding + w, padding + h);
  stateCtx.stroke();

  // Shared vertical scale: the largest magnitude across both histories, at least 1.
  let maxMag = 1;
  for (const history of [cellStateHistory, hiddenStateHistory]) {
    for (const state of history) {
      for (const val of state) {
        maxMag = Math.max(maxMag, Math.abs(val));
      }
    }
  }

  const timesteps = cellStateHistory.length;
  // Guard the divisor so a single timestep maps to x = padding instead of 0 / 0 = NaN.
  const xSpan = Math.max(timesteps - 1, 1);
  const toX = (t) => padding + (t / xSpan) * w;
  // Zero sits on the vertical centre line; +/-maxMag reaches 80% of the half-height.
  const toY = (value) => padding + h / 2 - (value / maxMag) * (h / 2 * 0.8);

  // Draw one thin line per unit of `history`, coloured `rgb`. The unit count is
  // taken from the recorded vectors rather than `cellStateSize`, because
  // updateCellSize() can change `cellStateSize` mid-run without resetting history.
  const plotHistory = (history, rgb) => {
    const steps = Math.min(history.length, timesteps);
    if (steps === 0) return;
    const units = history[0].length;
    for (let unit = 0; unit < units; unit++) {
      stateCtx.strokeStyle = `rgba(${rgb}, ${0.3 + 0.7 * unit / units})`;
      stateCtx.lineWidth = 1.5;
      stateCtx.beginPath();
      for (let t = 0; t < steps; t++) {
        const x = toX(t);
        const y = toY(history[t][unit]);
        if (t === 0) stateCtx.moveTo(x, y);
        else stateCtx.lineTo(x, y);
      }
      stateCtx.stroke();
    }
  };

  // Cell state (multiple thin lines)
  plotHistory(cellStateHistory, '156, 39, 176');
  // Hidden state (multiple thin lines)
  plotHistory(hiddenStateHistory, '0, 188, 212');

  // Draw zero line
  stateCtx.strokeStyle = '#cbd5e0';
  stateCtx.lineWidth = 1;
  stateCtx.setLineDash([5, 5]);
  stateCtx.beginPath();
  stateCtx.moveTo(padding, padding + h / 2);
  stateCtx.lineTo(padding + w, padding + h / 2);
  stateCtx.stroke();
  stateCtx.setLineDash([]);

  // Labels
  stateCtx.fillStyle = '#2d3748';
  stateCtx.font = 'bold 12px Arial';
  stateCtx.textAlign = 'center';
  stateCtx.fillText('Timestep', padding + w / 2, stateCanvas.height - 10);
  stateCtx.save();
  stateCtx.translate(15, padding + h / 2);
  stateCtx.rotate(-Math.PI / 2);
  stateCtx.fillText('State Value', 0, 0);
  stateCtx.restore();
}

/** Redraw every visualisation panel (architecture, heatmaps, state evolution). */
function drawAll() {
  drawCellArchitecture();
  drawHeatmaps();
  drawStateEvolution();
}

/**
 * Refresh the statistics read-outs: timestep counter, cell/hidden state norms
 * and the most recently processed token. Before any step has run, the norms
 * show "0.00" and the token shows "-".
 */
function updateStats() {
  document.getElementById('timestepValue').textContent = `${currentStep} / ${sequence.length}`;

  if (currentStep > 0) {
    document.getElementById('cellNormValue').textContent = vectorNorm(cellState).toFixed(2);
    document.getElementById('hiddenNormValue').textContent = vectorNorm(hiddenState).toFixed(2);
    // currentStep counts tokens already processed, so the latest one is at index currentStep - 1.
    document.getElementById('currentTokenValue').textContent =
      currentStep <= sequence.length ? sequence[currentStep - 1] : '-';
  } else {
    document.getElementById('cellNormValue').textContent = '0.00';
    document.getElementById('hiddenNormValue').textContent = '0.00';
    document.getElementById('currentTokenValue').textContent = '-';
  }
}

/**
 * Read the cell-size slider into `cellStateSize` and update its label. The
 * LSTM is rebuilt (via resetLSTM) only when no run is in progress.
 */
function updateCellSize() {
  cellStateSize = Number.parseInt(document.getElementById('cellSizeSlider').value, 10);
  document.getElementById('cellSizeValue').textContent = `${cellStateSize} units`;
  if (!isProcessing) {
    resetLSTM();
  }
}

/**
 * Apply the speed slider (1 = slow, 2 = medium, 3 = fast) to `animationSpeed`
 * and its label. A running animation is restarted so the new delay takes effect.
 */
function updateSpeed() {
  const speed = Number.parseInt(document.getElementById('speedSlider').value, 10);
  const speeds = { 1: 'Slow (400ms)', 2: 'Medium (250ms)', 3: 'Fast (150ms)' };
  const delays = { 1: 400, 2: 250, 3: 150 };
  // Ignore values outside the known presets. Otherwise the label would read
  // "undefined" and setInterval would get an undefined (effectively 0 ms) delay.
  if (delays[speed] === undefined) return;

  document.getElementById('speedValue').textContent = speeds[speed];
  animationSpeed = delays[speed];
  // An interval's period is fixed at creation, so restart it to apply the new delay.
  if (isProcessing) {
    pauseProcessing();
    startProcessing();
  }
}

/** Copy example task `taskIndex` from `exampleTasks` into the input box and reset the LSTM. */
function loadTask(taskIndex) {
  document.getElementById('sequenceInput').value = exampleTasks[taskIndex];
  resetLSTM();
}

/**
 * Begin (or resume) automatic stepping through the token sequence.
 *
 * On a fresh run (currentStep === 0), the input box is re-tokenized and the
 * LSTM re-initialised. Tokens are then processed every `animationSpeed` ms
 * until the sequence is exhausted, at which point the status banner shows
 * "Sequence complete". Does nothing if a run is already active or no tokens remain.
 */
function startProcessing() {
  if (isProcessing) return;

  if (currentStep === 0) {
    sequence = tokenize(document.getElementById('sequenceInput').value);
    initializeLSTM();
  }
  if (currentStep >= sequence.length) {
    return;
  }

  isProcessing = true;
  const status = document.getElementById('statusIndicator');
  status.textContent = 'Processing sequence...';
  status.style.background = '#d4edda';
  status.style.color = '#155724';

  processingInterval = setInterval(() => {
    stepForward();
    if (currentStep >= sequence.length) {
      pauseProcessing();
      status.textContent = 'Sequence complete';
      status.style.background = '#d1ecf1';
      status.style.color = '#0c5460';
    }
  }, animationSpeed);
}
/**
 * Stop automatic stepping and cancel the pending interval timer, if any.
 * When the run was interrupted part-way through the sequence (some tokens
 * processed, some remaining), the status banner switches to "Paused".
 * Otherwise the banner is left untouched.
 */
function pauseProcessing() {
  isProcessing = false;

  if (processingInterval) {
    clearInterval(processingInterval);
    processingInterval = null;
  }

  const interruptedMidway = currentStep < sequence.length && currentStep > 0;
  if (!interruptedMidway) {
    return;
  }

  const indicator = document.getElementById('statusIndicator');
  indicator.textContent = 'Paused';
  indicator.style.background = '#fff3cd';
  indicator.style.color = '#856404';
}
/**
 * Advance the LSTM by exactly one token, then refresh the stats and canvases.
 *
 * If no token has been processed yet, the input box is re-tokenized and the
 * LSTM re-initialised first. Once the whole sequence has been consumed, the
 * call is a no-op.
 */
function stepForward() {
  const isFreshRun = currentStep === 0;
  if (isFreshRun) {
    const rawInput = document.getElementById('sequenceInput').value;
    sequence = tokenize(rawInput);
    initializeLSTM();
  }

  const hasTokenLeft = currentStep < sequence.length;
  if (!hasTokenLeft) return;

  const token = sequence[currentStep];
  processToken(token);
  currentStep += 1;

  updateStats();
  drawAll();
}
/**
 * Return the visualizer to its starting state. This stops any running
 * animation, re-tokenizes the input box, re-initialises the LSTM, redraws the
 * stats and canvases, and shows the "ready" status banner.
 */
function resetLSTM() {
  pauseProcessing();

  const rawInput = document.getElementById('sequenceInput').value;
  sequence = tokenize(rawInput);
  initializeLSTM();

  updateStats();
  drawAll();

  const banner = document.getElementById('statusIndicator');
  banner.textContent = 'Ready to process sequence';
  Object.assign(banner.style, { background: '#e3f2fd', color: '#1565c0' });
}
// Initialize: runs once when this script is evaluated. resetLSTM() stops any
// timer, tokenizes the input box, builds the LSTM, draws every panel and shows
// the "Ready to process sequence" banner.
resetLSTM();