With the rise of generative AI systems like OpenAI's ChatGPT or Google's Gemini, I've become particularly interested in how they actually work. A quick search reveals that these systems are built on what's called a 'large language model' or LLM for short. An LLM is essentially a huge neural network, which is a mathematical model inspired by biological brains, that can be trained on massive amounts of data to acquire a sort of 'natural understanding' of what it's shown. Neural networks are used for things like classification of handwritten letters, face detection in images, and of course, language prediction in the case of LLMs.
But even knowing that, the question remains, how
exactly do you go from an input string like
At a high level, the whole system can be split into two major parts. First, the training of the model: this is where it 'learns' from existing data (in our case pre-written text). Second, the generation using the trained model: this is where the model takes an input (a seed string) and produces a continuation based on what it learned. However, at no point does the model actually handle bare text. Instead, everything is broken down into smaller parts, called tokens. A token can be anything from multiple words to a single character. Try clicking the 'toggle token' button in the demonstration above to see exactly how text is split.
To generate a fitting continuation of
Understanding the training process requires looking at how a Markov chain can 'learn' to predict token sequences in the first place. Recall that a Markov chain is a stochastic process \(X_\mathbb{N} \subset S\) with state space \(S\) such that \[ P(X_{n+1} = s \,|\, X_n,\ldots,X_1) = P(X_{n+1} = s \,|\, X_n) \] holds for all \(s \in S\) and all \(n\). This means the chain only cares about its current state and not how it got there, which is both its biggest strength and its biggest limitation. If the state space is finite, i.e. \(S = \{s_1,\ldots,s_m\}\), we define the transition probabilities \[ p_{ij} := P(X_{n+1} = s_j \,|\, X_n = s_i)\,. \] Technically \(p_{ij}\) may depend on \(n\), but in many applications (including ours) this is undesired. We therefore assume that \(p_{ij}\) is constant in \(n\) and call this property (time-)homogeneity. Finally, all these probabilities are collected in a matrix \(\mathbb{P} := (p_{ij})_{ij}\) called the transition matrix. So how does this allow us to 'learn' probable token successions? Consider this training text (from the model president-3, trained on U.S. inauguration speeches):
TXT
My fellow citizens: I stand here today humbled
by the task before us, grateful for the trust
you have bestowed, mindful of the sacrifices
borne by our ancestors.
To build a Markov chain we need two things: Its state space \(S\) and its transition
matrix \(\mathbb{P}\). The state space naturally arises from tokenizing the training
text (go ahead and try it) and treating each unique token as a state i.e.
PSEUDO
token = split input text
for i = 1,...,N do:
// find correct transition
indexA = index of token[i-1] in S
indexB = index of token[i] in S
// update corresponding weight
weight[indexA][indexB] += 1
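The counting loop above can be sketched in runnable JavaScript. This is only an illustration under assumptions of mine: the helper names `countTransitions` and `rowProbabilities` are hypothetical, and a nested `Map` stands in for the index-based weight matrix of the pseudocode.

```javascript
// Hypothetical helper: counts how often each token follows another.
// Assumes `tokens` is an already tokenized array of strings.
function countTransitions(tokens) {
  const weights = new Map(); // token -> Map(next token -> count)
  for (let i = 1; i < tokens.length; i++) {
    const from = tokens[i - 1];
    const to = tokens[i];
    if (!weights.has(from)) weights.set(from, new Map());
    const row = weights.get(from);
    row.set(to, (row.get(to) || 0) + 1);
  }
  return weights;
}

// Hypothetical helper: turns one row of counts into probabilities p_ij
// by dividing each count by the row total.
function rowProbabilities(row) {
  let total = 0;
  for (const count of row.values()) total += count;
  const probs = new Map();
  for (const [to, count] of row) probs.set(to, count / total);
  return probs;
}

const sample = ['the', 'cat', 'and', 'the', 'dog', 'and', 'the', 'cat'];
const weights = countTransitions(sample);
console.log(weights.get('the'));                   // counts of what follows 'the'
console.log(rowProbabilities(weights.get('the'))); // the same row, normalized
```

Storing raw counts and normalizing only when needed keeps training a matter of simple increments, which is presumably why the pseudocode works with weights rather than probabilities.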
In our current example the first few weights are
Appearing both in training and generation, the
tokenizer acts as the translation layer between human and model. Despite its
importance the actual logic behind it is quite simple: Define a list of characters
that separate tokens e.g.
The actual code introduces a
JS
Tokenizer.split = function(input) {
/* clean up input first */
const output = []; let token = '';
// pushes current token then resets it
const flush = (str) => {
if (!str.length) return;
output.push(str); token = '';
};
// builds up token and checks when to flush
for (const char of input) {...}
return output;
}
A small helper function
JS
for (const char of input) {
/* skip ignored characters */
const fullWord = separator.includes(char);
const specChar = special.includes(char);
// special chars form token on their own
if (fullWord || specChar) flush(token);
if (specChar) flush(char);
if (!fullWord && !specChar) token += char;
}
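For experimenting, the two fragments can be combined into a self-contained sketch. The character lists `SEPARATORS` and `SPECIAL`, the function name `splitTokens`, and the final `flush()` after the loop are assumptions made for this illustration, since the original clean-up and skip logic is elided.

```javascript
// Assumed character lists, chosen only for this illustration.
const SEPARATORS = [' ', '\n', '\t'];
const SPECIAL = ['.', ',', ':', ';', '!', '?'];

function splitTokens(input) {
  const output = [];
  let token = '';
  // push the token built so far (if any) and start a new one
  const flush = () => {
    if (token.length) output.push(token);
    token = '';
  };
  for (const char of input) {
    const isSeparator = SEPARATORS.includes(char);
    const isSpecial = SPECIAL.includes(char);
    // both kinds of characters end the current token
    if (isSeparator || isSpecial) flush();
    // special characters form a token on their own
    if (isSpecial) output.push(char);
    if (!isSeparator && !isSpecial) token += char;
  }
  flush(); // assumption: emit a trailing token that has no separator after it
  return output;
}

console.log(splitTokens('My fellow citizens: I stand here today'));
```

Separators are dropped entirely, while special characters survive as tokens, so punctuation can later be learned and generated like any other state.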
Next we'll cover the concept of a hashmap. A hashmap is a data structure that
enables extremely fast lookups, even when handling millions of elements. This will
later be crucial for keeping our Markov chain efficient. JavaScript already provides a
built-in version of this via
The core component of a hashmap is the so-called
hash function. For our purposes, a hash function \(h: \mathcal{X} \to
\mathbb{N}\) is a mapping between \(\mathcal{X}\), the set of all possible strings
given some alphabet \(A\), and \(\mathbb{N}\) (in practice often restricted to
32 bits, for example). The value \(h(\chi)\) is called the hash code of string
\(\chi\). There are many different ways to define such a function. For instance
\[ h(c_1 \cdots c_n) := \sum_{k \leq n} \iota(c_k) \]
where \(\chi = c_1 \cdots c_n\) is represented by its characters and \(\iota: A \to
\mathbb{N}\) is an embedding of the given alphabet into \(\mathbb{N}\) (e.g. using
As our hash function, we'll implement the
JS
Hashmap._hashDJB2a = function(string) {
let hash = 5381;
for (let i = 0; i < string.length; i++) {
// multiply by 33 via left shift
hash = (hash << 5) + hash;
hash ^= string.charCodeAt(i);
}
// modulo 2^32 via unsigned right shift
return hash >>> 0;
}
While the implementation overall is fairly straightforward, two clever details are
worth noting. First, multiplying by \(33\) can be done efficiently by shifting left by
\(5\) bits (equivalent to multiplying by \(2^5 = 32\)) and then adding the original
value once. Second, taking the remainder modulo \(2^{32}\) simply forces the hash to
an unsigned 32-bit integer, which in JavaScript can be achieved with the unsigned
right shift operator (>>>).
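Both details can be checked with a small standalone sketch. The function names `hashDJB2a` and `hashMul33` are hypothetical; the second variant uses the built-in `Math.imul` for an explicit 32-bit multiplication by 33, purely as a comparison and not as a suggestion that the project does it this way.

```javascript
// Standalone copy of the shift-and-add variant, for experimenting.
function hashDJB2a(string) {
  let hash = 5381;
  for (let i = 0; i < string.length; i++) {
    // (hash << 5) + hash equals 33 * hash modulo 2^32
    hash = ((hash << 5) + hash) ^ string.charCodeAt(i);
  }
  return hash >>> 0; // reinterpret as unsigned 32-bit integer
}

// Comparison variant with an explicit 32-bit multiplication.
function hashMul33(string) {
  let hash = 5381;
  for (let i = 0; i < string.length; i++) {
    hash = Math.imul(hash, 33) ^ string.charCodeAt(i);
  }
  return hash >>> 0;
}

for (const word of ['fellow', 'citizens', 'humbled']) {
  console.log(word, hashDJB2a(word), hashDJB2a(word) === hashMul33(word));
}
console.log(-1 >>> 0); // a negative 32-bit value mapped into [0, 2^32)
```

The bitwise operators in JavaScript work on 32-bit integers, so the intermediate values wrap around on their own and only the final result needs the unsigned conversion.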
From this we can introduce the
JS
Hashmap._index = function(ID) {
let index = this._hash(ID);
// modulo 2^k via bitwise and
index &= (1 << this._power) - 1;
return index;
}
Hashmap.find = function(ID, match) {
/* make sure ID is valid first */
return this._map[this._index(ID)].find(match);
}
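The masking step in `_index` can be tried in isolation. The helper `maskIndex` below is a hypothetical name, and the sketch assumes an unsigned 32-bit hash and a power between 0 and 30.

```javascript
// Hypothetical helper mirroring the masking step in _index.
const maskIndex = (hash, power) => hash & ((1 << power) - 1);

const power = 4;          // 2^4 = 16 buckets
const hash = 0xDEADBEEF;  // arbitrary unsigned 32-bit value
console.log(maskIndex(hash, power)); // keeps only the lowest 4 bits
console.log(hash % (1 << power));    // remainder modulo 16, for comparison
```

Both lines print the same bucket index, which is the reason the number of buckets is kept at a power of two.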
By keeping the size of
Proof. Let \((x)_2 = x_{n} \cdots x_1\) be an \(n\)-bit integer written in base
\(2\). Then taking the remainder modulo \(2^k\) for \(k \leq n\) simply gives the
lower \(k\) bits \(x_{k} \cdots x_1\). Equivalently
\[ x \wedge (2^k - 1) \,=\, x_{n} \cdots x_1 \wedge 0 \cdots 0 \underbrace{1 \cdots
1}_{k\,\text{times}} =\, x_{k} \cdots x_1 \]
where \(\wedge\) denotes the bitwise
Next we'll look at how elements are added to the hashmap. At first glance this seems
straightforward. Just hash each
JS
Hashmap.add = function(ID, el) {
this._map[this._index(ID)].push(el);
this._total += 1;
// trigger resize if N / M > alpha_max where
// N = total elements and M = total buckets
const load = this._total / (1 << this._power);
if (load > this.alpha) this._resize();
}
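To see how hashing, indexing, adding and resizing interact, here is a minimal separate-chaining hashmap. It is a sketch under assumptions, not the project's implementation: the class name `TinyHashmap`, the stored `{ ID, el }` entries, the default load factor of 0.75, and the doubling strategy in `_resize` are all choices made for this illustration.

```javascript
// Minimal separate-chaining hashmap; all names are illustrative assumptions.
class TinyHashmap {
  constructor({ power = 2, alpha = 0.75 } = {}) {
    this._power = power; // number of buckets is 2^power
    this.alpha = alpha;  // maximum load factor before resizing
    this._total = 0;
    this._map = this._emptyBuckets(power);
  }
  _emptyBuckets(power) {
    return Array.from({ length: 1 << power }, () => []);
  }
  // DJB2a-style string hash, forced to an unsigned 32-bit integer
  _hash(ID) {
    let hash = 5381;
    for (let i = 0; i < ID.length; i++) {
      hash = ((hash << 5) + hash) ^ ID.charCodeAt(i);
    }
    return hash >>> 0;
  }
  _index(ID, power = this._power) {
    return this._hash(ID) & ((1 << power) - 1);
  }
  add(ID, el) {
    this._map[this._index(ID)].push({ ID, el });
    this._total += 1;
    const load = this._total / (1 << this._power);
    if (load > this.alpha) this._resize();
  }
  find(ID) {
    const entry = this._map[this._index(ID)].find(e => e.ID === ID);
    return entry ? entry.el : undefined;
  }
  // assumed strategy: double the bucket count and redistribute all entries
  _resize() {
    const power = this._power + 1;
    const next = this._emptyBuckets(power);
    for (const bucket of this._map) {
      for (const entry of bucket) next[this._index(entry.ID, power)].push(entry);
    }
    this._power = power;
    this._map = next;
  }
}

const map = new TinyHashmap();
['my', 'fellow', 'citizens', 'I', 'stand'].forEach((t, i) => map.add(t, i));
console.log(map.find('citizens'), map._map.length);
```

Starting with 4 buckets, the fourth insertion pushes the load above 0.75, so the map grows to 8 buckets while every element stays findable.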
As discussed, a Markov chain is defined by its state space \(S\) and the transition probabilities \(p_{ij}\) between those states. One way to think about this is a directed graph \((V,E)\), where each vertex represents a state \((V = S)\) and each edge \(e_{ij} \in E\) connects the \(i\)-th state to the \(j\)-th one with a weight \(\omega_{ij}\) that counts how often this transition occurs in the training text, so that \(p_{ij} = \omega_{ij} / \sum_k \omega_{ik}\). This perspective not only works well as the foundation for our implementation, but also provides a clear and intuitive way to visualize the structure of the chain.
To implement this structure, we begin by introducing two classes:
JS
Vertex.addEdge = function({ targetID, weight = 1 }) {
/* update total weight */
// if edge already exists just update weight
let edge = this._findEdge(targetID);
if (edge) { edge.addWeight(weight); return; }
// else create new edge and add it to _edges
edge = new Edge({ targetID, weight });
this._edges.push(edge);
}
From here we can introduce the
JS
Chain.addVertex = function({ ID, edges = [] }) {
// if vertex already exists just update edges
let vertex = this._findVertex(ID);
const update = edge => vertex.addEdge(edge);
if (vertex) { edges.forEach(update); return; }
// else create new vertex and add it to _vertices
vertex = new Vertex({ ID });
const create = edge => vertex.addEdge(edge);
edges.forEach(create);
this._vertices.add(ID, vertex);
}
Notice how both, updating and creating an edge, simply mean to call
JS
Chain.nextState = function(/* --- */) {
/* handle undefined state */
/* define random integer */
const pivot = randInt(this._state.weight);
let threshold = 0;
// pick random edge according to their weights
// and advance current state to target vertex
for (const { weight, targetID } of this._state.edges) {
// update threshold with weight then check pivot
if ((threshold += weight) > pivot) {
this._state = this._findVertex(targetID);
return this._state;
}
}
}
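The threshold loop is a standard way of sampling in proportion to weights, and it can be tested outside the class. In the sketch below, `pickWeighted` is a hypothetical name, the weights are assumed to be positive integers, and the random source is passed in as a parameter so the behaviour can be checked deterministically.

```javascript
// Picks an edge with probability proportional to its (integer) weight.
// `random` is injectable so the sketch can be tested deterministically.
function pickWeighted(edges, random = Math.random) {
  const total = edges.reduce((sum, edge) => sum + edge.weight, 0);
  const pivot = Math.floor(random() * total); // integer in [0, total)
  let threshold = 0;
  for (const edge of edges) {
    // accumulate weights until the running sum passes the pivot
    threshold += edge.weight;
    if (threshold > pivot) return edge;
  }
  return undefined; // only reached when there are no edges
}

const edges = [
  { targetID: 'cat', weight: 2 },
  { targetID: 'dog', weight: 1 },
];
console.log(pickWeighted(edges).targetID); // 'cat' about two thirds of the time
```

Of the three possible pivots 0, 1 and 2, the first two fall within the weight of 'cat' and only the last one reaches 'dog', which gives exactly the ratio 2 to 1.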
Before moving on, we'll introduce a concept called context depth. So far, we've only discussed handling sequences of individual tokens. If you try setting the context depth in the demo to 'low' and generate some text, you'll notice that much of it turns out to be semantic nonsense. However, this is expected, as a simple Markov chain has no ability to 'learn' anything beyond statistical token succession. Still, a simple generalization of what we've done so far can help mitigate this limitation.
Instead of treating each unique token as an individual state, we can group multiple
tokens together to form a single state. For example, where a simple training text like
JS
Chain._depthID = function(token, depth, i) {
return token.slice(i, i + depth);
}
Training a complete model therefore involves creating multiple Markov chains with
different context depths. To build up a chain, first tokenize the input text and then
iterate over the resulting token list. At each step, add a vertex of
JS
Chain.trainFrom = function(token, depth = this.depth) {
/* handle mismatched depth */
const train = (depth) => {
const ID = (i) => this._depthID(token, depth, i);
// iterate token list and create vertices accordingly
for (let i = 0; i < token.length - depth; i++) {
const edges = [{ targetID: ID(i + 1) }];
const vertex = { ID: ID(i), edges };
this.addVertex(vertex);
}
}
// create a separate markov chain for each depth
do train(depth); while (--depth > 0);
}
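To make the effect of the depth parameter concrete, the following sketch builds the transition counts for a single depth using plain `Map` objects. The function name `trainDepth` and the idea of joining a token window with a space to obtain a state key are assumptions for this illustration (it presumes tokens contain no spaces); the project itself keeps token lists as IDs.

```javascript
// Builds the transition counts of a chain with the given context depth.
// States are windows of `depth` tokens; joining them with a space is an
// assumption made here so that a window can serve as a Map key.
function trainDepth(tokens, depth) {
  const key = (i) => tokens.slice(i, i + depth).join(' ');
  const chain = new Map(); // state key -> Map(next state key -> weight)
  for (let i = 0; i < tokens.length - depth; i++) {
    const from = key(i);
    const to = key(i + 1); // window shifted by one token
    if (!chain.has(from)) chain.set(from, new Map());
    const edges = chain.get(from);
    edges.set(to, (edges.get(to) || 0) + 1);
  }
  return chain;
}

const sample = ['the', 'cat', 'and', 'the', 'dog', 'and', 'the', 'cat'];
console.log(trainDepth(sample, 1).get('the'));     // successors of a single token
console.log(trainDepth(sample, 2).get('and the')); // successors of a token pair
```

With depth 2, consecutive states overlap in one token, so the only new information in each transition is the last token of the target state. That is why generation later reads just the last element of a vertex ID.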
Notice that, just as before, we don't need to worry about whether a given vertex
already exists. Updating and creating a vertex are both handled automatically by
Generating output tokens from a given list of seed tokens is done by first deriving an
initial state from the seed and then repeatedly calling
JS
Chain.generate = function(
{ seed, length, depth = this.depth, /* --- */ }
) {
/* handle mismatched depth */
const output = [];
// 1. try to increase currently used context
...
// 2. if max context generate until length = 0
...
// 3. else add one new token to seed and retry
...
return output;
}
Since not every seed will necessarily lead to a state with the desired
JS
// 1. try to increase currently used context
let context = Math.min(seed.length, depth);
let vertex = {};
do {
const i = Math.max(seed.length - context, 0);
const ID = this._depthID(seed, context, i);
// if ID is empty get random vertex instead
if (!ID.length)
vertex = this._vertices.getRandom(depth);
else vertex = this.setState(ID);
}
// pre-decrement to not include depth = 0
while(!vertex && --context > 0);
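The core of this search, shrinking the context until a known state is found, can be written as a small standalone function. This is a simplified sketch: `longestKnownContext` and the `known` set of space-joined state keys are hypothetical, and the fallback to a random vertex for an empty seed is left out.

```javascript
// Returns the longest suffix of `seed` (at most `depth` tokens) that is a
// known state, or null if none is. `known` is a hypothetical Set of state
// keys, with tokens joined by a single space.
function longestKnownContext(seed, depth, known) {
  for (let context = Math.min(seed.length, depth); context > 0; context--) {
    const suffix = seed.slice(seed.length - context);
    if (known.has(suffix.join(' '))) return suffix;
  }
  return null;
}

const known = new Set(['the cat', 'cat', 'dog']);
console.log(longestKnownContext(['see', 'the', 'cat'], 2, known)); // full depth 2
console.log(longestKnownContext(['a', 'big', 'dog'], 2, known));   // falls back to depth 1
```

Trying the longest context first means the generator always uses as much of the seed as the training data supports before settling for less.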
Once the maximum
JS
// 2. if max context generate until length = 0
if (context == depth) {
while(length-- > 0 /* --- */) {
const vertex = this.nextState();
const token = vertex.ID.last();
output.push(token);
/* additional scaffolding */
}
}
Finally, if the initial context search fails to result in a state of the desired
JS
// 3. else add one new token to seed and retry
else {
// context + 1 since it's decremented once extra
const vertex = this.nextState(context + 1);
const token = vertex.ID.last();
output.push(token);
// include new token and retry increasing depth
const newSeed = [...seed, token]; length--;
const subset = this.generate(
{ seed: newSeed, length, depth, /* --- */ });
output.push(...subset);
}
As already mentioned, all a Markov chain can really 'learn' is the statistical
succession of tokens. This means that if our training data contains the phrase
Consider a Markov chain defined by its state space \(S\) and the transition probabilities \(p_{ij}\). Our goal is to find a function \(\phi : S^{n+1} \to [0,1]\) such that 'mostly unique' sequences map close to \(1\), while 'mostly derived' ones map close to \(0\). For any state sequence \(s_0,\ldots,s_n \in S\) define \[ \phi_k := 1 - p_{k-1,k} = 1- \frac{\omega_{k-1,k}}{\sum_j \omega_{k-1,j}} \,.\] If the transition \(s_{k-1} \to s_k\) is unlikely, \(\phi_k\) will be close to \(1\). If it is common, \(\phi_k\) will be close to \(0\). The most natural way to combine the values \(\phi_1,\ldots,\phi_n\) into a single measure \(\phi(s_0,\ldots,s_n)\) would be to take their average. However, if the sequence consists of mostly derivative blocks, where unique transitions only happen between them, the average becomes highly skewed and fails to represent the overall derivativeness of the sequence accurately.
A better approach could be to use the median, that is, the middle value of all sorted
\(\phi_k\). Unlike the average, the median is resistant to skewing from a few highly
unusual transitions. However, it can also be too inert: When the sequence consists of
many short but mostly derivative blocks separated by unique transitions, the median
would likely not reflect this at all, since it would remain dominated by the frequent
common transitions within those blocks. A practical solution is to combine
the two, the average \(\overline{\phi}_n\) and the median \(\widetilde{\phi}_n\),
as follows
\[ \begin{align} \phi(s_0,\ldots,s_n)\, &:=\, (1 - \alpha)\, \overline{\phi}_n \,+\,
\alpha\, \widetilde{\phi}_n \\[5px] &=\, \frac{1 - \alpha}{n} \sum_{k} \phi_k \,+\,
\frac{\alpha}{2}\, \Big[\phi_{\big(\lfloor \frac{n+1}{2} \rfloor\big)} +
\phi_{\big(\lceil \frac{n+1}{2} \rceil\big)}\Big] \end{align} \]
where \(\alpha \in [0,1]\) and \(\phi_{(k)}\) is the \(k\)-th sorted value. I chose
\(\alpha = 0.6\), giving the median a slightly stronger influence than the average.
Implementing the final formula for \(\phi(s_0,\ldots,s_n)\) is straightforward and
mainly involves computing each \(\phi_k\) i.e. the
JS
const uniqueness = (ID1, ID2) => {
const edges = this.getEdges(ID1);
// corresponds to omega_(k-1)_k
const target = (edge) => edge.targetID.equals(ID2);
const choiceWeight = edges.find(target)?.weight || 0;
// corresponds to sum_j omega_(k-1)_j
const sumWeight = (sum, edge) => sum + edge.weight;
const totalWeight = edges.reduce(sumWeight, 0);
return 1 - choiceWeight / totalWeight;
};
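Once the individual values \(\phi_k\) are known, the blend of average and median can be computed as in the sketch below. The function name `blendedUniqueness` is hypothetical and the input is assumed to be a plain array of numbers in \([0,1]\); the default \(\alpha = 0.6\) follows the choice mentioned above.

```javascript
// Blends mean and median of the per-transition values phi_k.
function blendedUniqueness(phis, alpha = 0.6) {
  const n = phis.length;
  if (n === 0) return NaN; // the measure is undefined without transitions
  const mean = phis.reduce((sum, phi) => sum + phi, 0) / n;
  const sorted = [...phis].sort((a, b) => a - b);
  // 1-based positions floor((n+1)/2) and ceil((n+1)/2), shifted to 0-based
  const lower = sorted[Math.floor((n + 1) / 2) - 1];
  const upper = sorted[Math.ceil((n + 1) / 2) - 1];
  const median = (lower + upper) / 2;
  return (1 - alpha) * mean + alpha * median;
}

// mostly derived sequence with a single unusual transition
console.log(blendedUniqueness([0, 0, 0, 0, 1]));
// the same values judged by the average alone (alpha = 0)
console.log(blendedUniqueness([0, 0, 0, 0, 1], 0));
```

In this example the single unique transition lifts the average to 0.2, while the median stays at 0, so the blended value ends up well below the plain average.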
Working on this project has been not only a
If this has sparked your curiosity and you'd like to learn more about the actual core of LLMs, that is neural networks, I encourage you to check out my other project on the classification of handwritten letters. If you have any questions or suggestions, feel free to contact me. Finally, thank you so much for reading all the way to the end!