mlx-gpt2


Implementing GPT-2 with the MLX framework: a from-scratch deep learning journey

This project walks through implementing a GPT-2 model from scratch with the MLX framework, covering the core steps of data preparation, vocabulary creation, and model architecture design. The implementation depends only on the MLX and NumPy libraries and can be trained quickly on a MacBook to generate Shakespeare-style text. It follows the approach of Karpathy's GPT tutorial, re-implemented in MLX, and serves as a practical guide for deep learning enthusiasts.


GPT-2 From Scratch in MLX

train.py is ~200 lines of Python code that defines and trains GPT-2 from scratch, using mlx and numpy as the only dependencies. This README details writing train.py from scratch. The model is trained on ~1 million characters of Shakespeare contained in input.txt, and it can be trained in around 10 minutes on a MacBook to produce coherent Shakespeare-like text.
This tutorial follows Karpathy's GPT from scratch but in MLX.

Table of Contents

Preparing the data

Install mlx and run the following imports.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as utils
import numpy as np
import math

The first step to training an LLM is collecting a large corpus of text data and then tokenizing it. Tokenization is the process of mapping text to integers, which can be fed into the LLM. Our training corpus for this model will be the works of Shakespeare concatenated into one file. This is roughly 1 million characters and looks like this:

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.
...

First, we read the file as a single long string into the text variable. Then we use the set() function to get all the unique characters in the text, which will be our vocabulary. By printing vocab you can see all the characters in our vocabulary as one string; we have a total of 65 characters, which will be our tokens.

Creating the vocabulary

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
vocab = sorted(list(set(text)))
vocab_size = len(vocab)
print(''.join(vocab))
# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
print(vocab_size)
# 65

Production models will use tokenization algorithms like byte-pair encoding to generate a larger vocabulary of sub-word chunks. Since our focus today is on the architecture, we will continue with character-level tokenization. Next, we will map our vocabulary to integers known as token IDs. Then we can encode our text into tokens and decode them back to a string.

# Create mapping from vocab to integers
itos = {i:c for i,c in enumerate(vocab)} # int to string
stoi = {c:i for i,c in enumerate(vocab)} # string to int
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])
print(encode("hello world")) # [46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
print(decode(encode("hello world"))) # hello world

We use the enumerate() function to iterate over all characters and their indices in the vocabulary and create a dictionary itos, which maps integers to characters, and stoi, which maps strings to integers. Then we use these mappings to create our encode and decode functions. Now we can encode the entire text and split it into training and validation data.

data = encode(text)
split = int(0.9 * len(data))
train_data = data[:split]
val_data = data[split:]

Currently, our training data is just one very long string of tokens. However, we are trying to train our model to predict the next token given some previous tokens. Therefore our dataset should consist of examples where the input is some string of tokens and the label is the correct next token. We need to define a model parameter called the context length, which is the maximum number of tokens used to predict the next token. Each training example will be ctx_len tokens long. Let's look at the first ctx_len+1 tokens.

ctx_len = 8
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58, 1, 15, 47, 58]
# x: [18, 47, 56, 57, 58, 1, 15, 47] | y: 58

This is one training example where the input is "18, 47, 56, 57, 58, 1, 15, 47" and the desired output is "58". This is 8 tokens of context. However, we also want to train the model to predict the next token given only 7, 6, 5 … 0 tokens as context which is needed during generation. Therefore we also consider the 8 sub examples packed into this example:

ctx_len = 8
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58, 1, 15, 47, 58]
# 8 sub examples
# [18] --> 47
# [18, 47] --> 56
# [18, 47, 56] --> 57
# [18, 47, 56, 57] --> 58
# [18, 47, 56, 57, 58] --> 1
# [18, 47, 56, 57, 58, 1] --> 15
# [18, 47, 56, 57, 58, 1, 15] --> 47
# [18, 47, 56, 57, 58, 1, 15, 47] --> 58

Notice that the labels are simply the inputs shifted left.

print("inputs: ", train_data[:ctx_len]) print("labels: ", train_data[1:ctx_len+1]) # labels = inputs indexed 1 higher # inputs: [18, 47, 56, 57, 58, 1, 15, 47] # labels: [47, 56, 57, 58, 1, 15, 47, 58]

At index 0 the input is 18 and the label is 47. At index 1 the input is everything before and including index 1 which is [18, 47] and the label is 56, etc. Now that we understand that the labels are simply the input sequence indexed one higher we can build our datasets.

# Creating training and validation datasets
ctx_len = 8
X_train = mx.array([train_data[i:i+ctx_len] for i in range(0, len(train_data) - ctx_len, ctx_len)])
y_train = mx.array([train_data[i+1:i+ctx_len+1] for i in range(0, len(train_data) - ctx_len, ctx_len)])
X_val = mx.array([val_data[i:i+ctx_len] for i in range(0, len(val_data) - ctx_len, ctx_len)])
y_val = mx.array([val_data[i+1:i+ctx_len+1] for i in range(0, len(val_data) - ctx_len, ctx_len)])

We loop through the data and take chunks of size ctx_len as the inputs (X) and then take the same chunks but at 1 higher index as the labels (y). Then we take these Python lists and create mlx array objects from them. The model internals will be written with mlx so we want our inputs to be mlx arrays.

One more thing. During training we don't want to feed the model one example at a time, we want to feed it multiple examples in parallel for efficiency. This group of examples is called our batch, and the number of examples in a group is our batch size. Thus we define a function to generate batches for training.

def get_batches(X, y, b_size, shuffle=True):
    if shuffle:
        ix = np.arange(X.shape[0])
        np.random.shuffle(ix)
        ix = mx.array(ix)
        X = X[ix]
        y = y[ix]
    for i in range(0, X.shape[0], b_size):
        input = X[i:i+b_size]
        label = y[i:i+b_size]
        yield input, label

If shuffle=True, we shuffle the data by indexing it with a randomly shuffled index. Then we loop through our dataset and return batch-size chunks from input and label datasets. These chunks are known as mini-batches and are just stacked examples that we process in parallel. These mini-batches will be our input to the model during training.
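As a quick sanity check, we can pull one minibatch from the generator and look at its shape (the batch size of 4 here is just for illustration; X_train and y_train are the datasets built above):

for input, label in get_batches(X_train, y_train, 4):
    print(input.shape)  # (4, 8) -> (batch_size, ctx_len)
    print(label.shape)  # (4, 8)
    break               # inspect only the first minibatch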

Here's an example of a minibatch of 4 examples with context length 8.

[figure: a minibatch of 4 input sequences and their corresponding labels]

This minibatch packs 32 next-token prediction problems. The model will predict the next token for each token in the input, and the labels will be used to calculate the loss. Notice that the labels contain the next token for each index of the inputs.

You'll want to keep this picture in your mind because the shapes of these tensors will get hairy. For now, just remember that we will input a tensor of shape (batch_size, ctx_len) to the model.

Coding GPT-2

Let's look at the GPT-2 architecture to get an overview of what we are trying to implement.

[figure: the GPT-2 architecture diagram]

Don't worry if this looks confusing. We will implement it step by step from bottom to top. Let's start by implementing the input embeddings.

Input Embeddings

The purpose of the input embedding layer is to map token IDs to vectors. Each token will be mapped to a vector which will be its representation as it is forwarded through the model. The vectors for each token will accumulate and exchange information as they pass through the model and eventually be used to predict the next token. These vectors are called embeddings.

The simplest way to map token IDs to vectors is through a lookup table. We create a matrix of size (vocab_size, n_emb) where each row is the embedding vector for the corresponding token. This matrix is known as the embedding weights.

[figure: an embedding layer of size (65, 6) used as a lookup table]

The diagram shows an example embedding layer of size (65, 6). This means there are 65 tokens in the vocabulary and each one will be represented by a length 6 embedding vector. The inputted sequence will be used to index the embedding weights to get the vector corresponding to each token. Remember the minibatches we input into the model? Originally the minibatch is size (batch_size, ctx_len). After passing through the embedding layer it is size (batch_size, ctx_len, n_emb). Instead of each token being a single integer, each token is now a vector of length n_emb.
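To make the lookup concrete, here is a small standalone sketch (not part of the model) that indexes a random (vocab_size, n_emb) weight matrix with a minibatch of token IDs; the random weights, the batch of 4 examples, and n_emb = 6 are all just illustrative:

n_emb = 6                                            # illustrative embedding size
emb_weights = mx.random.normal((vocab_size, n_emb))  # (65, 6) lookup table
batch = X_train[:4]                                  # (4, 8) minibatch of token IDs
embedded = mx.take(emb_weights, batch, axis=0)       # each token ID picks out its row
print(batch.shape)     # (4, 8)    -> (batch_size, ctx_len)
print(embedded.shape)  # (4, 8, 6) -> (batch_size, ctx_len, n_emb)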

Let's define the embedding layer in code now.

n_emb = 6 # You can add these hyperparams at the top of your file

class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_emb)

We will define a class to organize our implementation. We subclass nn.Module to take advantage of mlx's features. Then in the init function, we call the superclass constructor and initialize our token embedding layer called wte.

Positional Embeddings

Next up is the positional embeddings. The purpose of positional embeddings is to encode information about the position of each token in the sequence. This can be added to our input embeddings to get a complete representation of each token that contains information about the token's position in the sequence.

class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_emb) # token embeddings
        self.wpe = nn.Embedding(ctx_len, n_emb)    # position embeddings

The position embeddings work the same as token embeddings, except instead of having a row for each token we have a row for each possible position index. This means our embedding weights will be of shape (ctx_len, n_emb). Now we implement the call function in our GPT class. This function will contain the forward pass of the model.

# Tensor shapes commented
def __call__(self, x):
    B, T = x.shape # (B = batch_size, T = ctx_len)
    tok_emb = self.wte(x) # (B, T, n_emb)
    pos_emb = self.wpe(mx.arange(T)) # (T, n_emb)
    x = tok_emb + pos_emb # (B, T, n_emb)

First, we break out the dimensions of our input into variables B and T for easy handling. In sequence modeling contexts B and T are usually used as shorthand for "batch" and "time" dimensions. In this case, the "time" dimension of our sequence is the context length.

Next, we calculate token and position embeddings. Notice that for the position embeddings, our input is mx.arange(T). This will output an array of consecutive integers from 0 to T-1 which is exactly what we want because those are the positions we want to embed. After passing that through the embedding layer we will have a tensor of shape (T, n_emb) because the embedding layer plucks out the n_emb length vector for each of the T positions. Note that even though pos_emb is not the same shape as tok_emb we can add the two because mlx will broadcast, or replicate pos_emb across the batch dimension to allow elementwise addition. Finally, we perform the addition to get the new representations of the tokens with positional information.
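Here is a standalone shape check of that broadcast, with made-up sizes (B=4, T=8, n_emb=6) and random stand-ins for the embeddings:

B, T, n_emb = 4, 8, 6
tok_emb = mx.random.normal((B, T, n_emb))  # one vector per token in the batch
pos_emb = mx.random.normal((T, n_emb))     # one vector per position, shared across the batch
x = tok_emb + pos_emb                      # pos_emb is broadcast over the batch dimension
print(x.shape)                             # (4, 8, 6)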

Self-Attention

So far the representation vectors for each token have been calculated independently. They have not had the opportunity to exchange any information. This is intuitively bad in language modeling because the meaning and usage of words depend on the surrounding context. Self-attention is how we incorporate information from previous tokens into a given token. First, let's consider a naive approach. What if we simply represented each token as the average of its representation vector and the vectors of all the tokens before it? This achieves our goal of packing information from previous tokens into the representation for a given token. Here's what it would look like.

[figure: representing each token as the average of itself and all previous tokens]
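A minimal sketch of this naive averaging, written as an explicit loop over a toy sequence (the sizes T=4 and n_emb=3 are made up to match the diagrams):

T, n_emb = 4, 3
tokens = mx.random.normal((T, n_emb))   # toy sequence: one vector per token
averaged = mx.stack([tokens[:t + 1].mean(axis=0) for t in range(T)])  # row t = mean of tokens 0..t
print(averaged.shape)                   # (4, 3)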

But self-attention doesn't involve writing a for-loop. The key insight is we can achieve this previous token averaging with matrix multiplication!

[figure: previous-token averaging expressed as a matrix multiplication]

By multiplying our input sequence on the left by a special matrix we get the desired result. This matrix is known as the attention weights. Notice that each row of the attention weight matrix specifies "how much" of every other token goes into the representation for a given token. For example, row two is [0.5, 0.5, 0, 0], so row two of the result will be 0.5*token1 + 0.5*token2 + 0*token3 + 0*token4, or the average of token1 and token2. Note that the attention weights form a lower-triangular matrix (zeros in the upper-right entries). This guarantees that tokens only communicate with previous tokens, never future ones, which matches generation: when generating, the model only has access to the previous tokens.
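Here is the same toy example done with a matrix multiplication instead of a loop: a lower-triangular matrix whose rows are normalized to sum to one reproduces the running average in one step.

T, n_emb = 4, 3
tokens = mx.random.normal((T, n_emb))
attn_weights = mx.tril(mx.ones((T, T)))                                # 1s on and below the diagonal
attn_weights = attn_weights / attn_weights.sum(axis=1, keepdims=True)  # normalize each row
print(attn_weights[1])                  # [0.5, 0.5, 0, 0], as in the example above
averaged = attn_weights @ tokens        # same result as the explicit loop
print(averaged.shape)                   # (4, 3)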

Let's look at how we can construct the attention weight matrix.

[figure: building the attention weights from zeros, a -inf mask, and a row-wise softmax]

Notice that if we create an array of zeros with -inf in the upper right entries and then perform row-wise softmax we get the desired attention weights. A good exercise is to step through the softmax calculation for a row to see how this works. The takeaway is that we can take some array of size (ctx_len, ctx_len) and softmax each row to get attention weights that sum to one.
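A sketch of that construction: start from zeros, add -inf strictly above the diagonal, and softmax each row. The result is exactly the averaging matrix from before.

T = 4
scores = mx.zeros((T, T))
mask = mx.triu(mx.full((T, T), float('-inf')), k=1)  # -inf strictly above the diagonal
weights = mx.softmax(scores + mask, axis=-1)         # row-wise softmax
print(weights)
# [[1.0,  0.0,  0.0,  0.0 ],
#  [0.5,  0.5,  0.0,  0.0 ],
#  [0.33, 0.33, 0.33, 0.0 ],
#  [0.25, 0.25, 0.25, 0.25]]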

Now we can leave the realm of naive self-attention. Instead of simply averaging previous tokens, we use arbitrary weighted sums over previous tokens. Notice what happens when we do row-wise softmax of an arbitrary matrix.

[figure: row-wise softmax of an arbitrary matrix still produces rows that sum to one]

We still get weights that sum to one on each row. During training, we can learn the numbers in the matrix on the left, which specify how much each token contributes to the representation of another token. This is how tokens pay "attention" to each other. But we still haven't said where this matrix comes from. These pre-softmax attention weights are calculated from the tokens themselves, but indirectly, through three linear projections.

Keys, Queries, and Values

[figure: computing keys, queries, and values and combining them into new representations]

Each token in our sequence emits 3 new vectors. These vectors are called keys, queries, and values. We use the dot product of the query vector of one token and the key vector of another token to quantify the "affinity" between those two tokens. We want the pairwise affinities of each token with every other token, so we multiply the query matrix (4x3) by the transposed key matrix (3x4) to get the raw attention weights (4x4). Because of how matrix multiplication works, the (i,j) entry in the raw attention weights is the query of token i dotted with the key of token j, i.e. the "affinity" between the two. Thus we have calculated interactions between every pair of tokens. However, we don't want past tokens interacting with future tokens, so we apply a mask of -inf to the upper-right entries to ensure they zero out after softmax. Then we perform row-wise softmax to get the final attention weights. Instead of multiplying these weights directly with the input, we multiply them with the value projection. This results in the new representations.
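Here is a compact sketch of that whole computation on a single toy sequence, with one head and no batch dimension. The head_size of 3 and the bias-free projections are illustrative choices, not the author's final Attention class, and the 1/sqrt(head_size) factor is the standard scaled dot-product used in GPT-2:

T, n_emb, head_size = 4, 6, 3
x = mx.random.normal((T, n_emb))               # toy token representations
key = nn.Linear(n_emb, head_size, bias=False)
query = nn.Linear(n_emb, head_size, bias=False)
value = nn.Linear(n_emb, head_size, bias=False)

K = key(x)                                     # (T, head_size)
Q = query(x)                                   # (T, head_size)
V = value(x)                                   # (T, head_size)

scores = (Q @ K.T) / math.sqrt(head_size)      # (T, T) pairwise affinities, scaled
mask = mx.triu(mx.full((T, T), float('-inf')), k=1)  # block attention to future tokens
weights = mx.softmax(scores + mask, axis=-1)   # causal attention weights, rows sum to 1
out = weights @ V                              # (T, head_size) new representations
print(out.shape)                               # (4, 3)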

Now that we understand attention conceptually, let's implement it.

class Attention(nn.Module): def
