Unlocking multi-threaded parallel inference on PyTorch models
PyTorch
Python
Free-Threading
No-GIL
LLM
GPT2
This post examines multi-threaded parallel inference on PyTorch models using the new No-GIL, free-threaded version of Python. Using a simple 124M parameter GPT2 model that we train from scratch, we explore new territory unlocked by free-threaded Python: parallel PyTorch model inference, where multiple threads, unimpeded by the Python GIL, attempt to generate text from a transformer-based model in parallel.
Introduction
Python 3.13, released in October 2024, is the first version of Python to introduce support for a “no-GIL” free-threaded mode, per PEP 703 (Making the Global Interpreter Lock Optional in CPython), unlocking the ability for multiple Python threads to run simultaneously.
This allows, for the first time since the language’s inception in December 1989, a single Python process to saturate all CPU cores in parallel with pure Python code (i.e. not farming out to extension modules written in C, C++, or, more recently, Rust).
A handful of the motivations captured in that PEP opine on how the GIL impedes Python AI workflows, particularly as it relates to GPU programming.
This blog post explores what can be done with PyTorch using the new free-threaded version of Python, focusing specifically on run-time inference on transformer-based generative models.
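You can test the claim above about saturating every core with pure Python using the minimal sketch below. It is not from the post's repository. It runs the same CPU-bound, pure-Python function on one thread and then on four. On a free-threaded build with the GIL disabled, the four-thread run should take roughly as long as the single-thread run. With the GIL, expect it to take roughly four times as long. The prime-counting workload is an arbitrary stand-in.

```python
import sys
import time
from concurrent.futures import ThreadPoolExecutor


def count_primes(limit: int) -> int:
    """Pure-Python, CPU-bound work: count primes below `limit` by trial division."""
    count = 0
    for n in range(2, limit):
        is_prime = True
        d = 2
        while d * d <= n:
            if n % d == 0:
                is_prime = False
                break
            d += 1
        if is_prime:
            count += 1
    return count


def run(num_threads: int, limit: int = 50_000) -> tuple[list[int], float]:
    """Run the same task once per thread; return the results and wall-clock time."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        results = list(pool.map(count_primes, [limit] * num_threads))
    return results, time.perf_counter() - start


if __name__ == "__main__":
    # sys._is_gil_enabled() only exists on 3.13+; assume a GIL otherwise.
    gil = getattr(sys, "_is_gil_enabled", lambda: True)()
    print(f"GIL enabled: {gil}")
    for n in (1, 4):
        _, elapsed = run(n)
        print(f"{n} thread(s): {elapsed:.2f}s")
```

Run it with `python -Xgil=0` under py313t, then with the regular py313 interpreter, and compare the two timings.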
Note
I didn’t focus on how training PyTorch models might look in the new free-threaded Python world for a couple of reasons.
Primarily, training is a lot more complex if you’re involving multiple nodes—as gradients need to be carefully synchronized at critical points, for example—and is well outside the scope of a simple blog post.
Additionally, there’s already a huge body of existing work tackling multi-node training in PyTorch by way of the Distributed Data Parallel, multiprocessing-based facilities exposed by torch.distributed.
On the flip side, no one has really explored what parallel inference might look like within a single Python process, because the GIL has prevented that from even being an option until now.
Getting Started
All of this work was done on Linux (Ubuntu 22.04) with Python 3.13t, PyTorch 2.6, and CUDA 12.6.
Full source code is provided for everything in this post. It is worth noting that in a few cases, I roll my own solutions for things that have existing solutions in the broader Python ecosystem. For example, at the tail end of the post, I leverage a multi-threaded asyncio-based HTTP server I wrote instead of using existing solutions like FastAPI.
The reason for this is that, as free-threaded Python is still in its infancy, a lot of packages do not work with it yet, especially those that rely on Cython, C or C++ code, or Rust.
In fact, Rust dependencies are particularly problematic due to the proliferation of Python projects leveraging PyO3 (Rust bindings for Python), especially prominent projects such as TikToken and Pydantic, upon which a lot of the Python AI ecosystem is built. PyO3 only recently grew support for free-threaded Python in their 0.23.3 release, which came out in December, 2024, and many dependent projects are yet to update to it.
Thus, this post and its supporting code should not be considered the state of the art for production deployments—its primary goal is exploratory in nature, and minimizing the number of moving pieces in the stack helps achieve this goal.
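The post's HTTP server isn't reproduced in this section. One common way to structure a multi-threaded asyncio server is to give every thread its own event loop. With the GIL disabled, those loops can then genuinely run in parallel. The minimal sketch below shows that pattern and assumes nothing about the actual Parallelopedia server. `handle()` is a hypothetical stand-in for per-request work.

```python
import asyncio
import threading


async def handle(request_id: int) -> str:
    """Hypothetical per-request handler; a real server would read a socket here."""
    await asyncio.sleep(0.01)  # stand-in for non-blocking I/O
    return f"{threading.current_thread().name}:{request_id}"


def run_loop(request_ids: list[int], results: list[str], lock: threading.Lock) -> None:
    """Thread body: asyncio.run() creates, runs, and closes a private event loop."""

    async def main() -> list[str]:
        return await asyncio.gather(*(handle(i) for i in request_ids))

    handled = asyncio.run(main())
    with lock:  # the shared results list is only mutated under the lock
        results.extend(handled)


def serve(num_threads: int, requests_per_thread: int) -> list[str]:
    """Start one thread (and thus one event loop) per worker and wait for all."""
    results: list[str] = []
    lock = threading.Lock()
    threads = []
    for t in range(num_threads):
        ids = list(range(t * requests_per_thread, (t + 1) * requests_per_thread))
        threads.append(
            threading.Thread(target=run_loop, args=(ids, results, lock), name=f"loop-{t}")
        )
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return results


if __name__ == "__main__":
    print(serve(num_threads=4, requests_per_thread=3))
```

Keeping each loop private to its thread avoids sharing asyncio objects across threads. asyncio objects are generally not thread-safe.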
Environment Setup Details
Environments
It is fiddly getting the environments set up in support of this post. Again, this is due to the infancy of free-threaded Python. So I apologize in advance for how long this environment setup section is.
I reference two conda environments in this post: a Python 3.13 free-threaded one named py313t, and a normal, not-free-threaded Python 3.13 one named py313.
The primary motivation behind the second py313 environment is it allows us to install Jupyter Lab, which, at the time of writing, still isn’t compatible with a Python free-threaded installation. However, we can still register a free-threaded Python kernel with Jupyter, which is all we really care about when running the code in this post in a free-threaded environment.
Details on creating the conda environments follow.
Free-Threaded 3.13 Env (py313t)
I use conda to create the Python 3.13 free-threaded environment plus initial dependencies, activate it, then install the remaining dependencies via pip, as follows:
nodejs is required for the UI component we’ll introduce later. regex, rust, and setuptools_rust are needed for tiktoken, described next. Finally, numpy is needed by torch, which we also install later.
TikToken
TikToken is a fast BPE tokenizer from OpenAI that is used extensively in the emerging Python LLM landscape. At the time of writing, the latest TikToken release was 0.8.0, which was built against PyO3 0.22.2, which isn’t compatible with free-threaded Python.
Thankfully, it was trivial to get a local installation of tiktoken working: clone the GitHub repo, bump the PyO3 version in Cargo.toml, then rebuild and install.
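The exact Cargo.toml edit depends on how tiktoken declares the dependency, which isn't reproduced here. If you'd rather script the bump than edit by hand, the minimal sketch below assumes a single-line `pyo3 = ...` entry, either as a plain version string or as an inline table containing `version = "..."`. The 0.23.3 version is the PyO3 release mentioned earlier. `bump_pyo3()` is a hypothetical helper.

```python
import re

# Matches a one-line `pyo3 = ...` dependency and captures its version string.
# Assumed forms:  pyo3 = "0.22.2"   or   pyo3 = { version = "0.22.2", ... }
PYO3_RE = re.compile(
    r'^(?P<prefix>pyo3\s*=\s*(?:\{[^}\n]*?version\s*=\s*)?")'
    r'(?P<version>[^"]+)(?P<suffix>")',
    re.MULTILINE,
)


def bump_pyo3(cargo_toml: str, new_version: str) -> str:
    """Return `cargo_toml` with the pyo3 dependency version replaced."""
    new_text, n = PYO3_RE.subn(
        lambda m: m.group("prefix") + new_version + m.group("suffix"), cargo_toml
    )
    if n != 1:
        raise ValueError(f"expected exactly one pyo3 dependency line, found {n}")
    return new_text
```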
Note
This is a perfect example of the type of fiddling around I wanted to avoid by not depending on any external packages other than the bare necessities, such as PyTorch. I made an exception for tiktoken because a) it’s arguably as important a part of the LLM stack as torch, and b) thankfully, it wasn’t too difficult to get a compatible version of tiktoken installed locally.
Clone the tiktoken git repo and cd into it as follows:
After this, you should be able to import the tiktoken module in Python:
% cd ..
% python -Xgil=0
Python 3.13.1 experimental free-threading build | packaged by conda-forge | (main, Jan 13 2025, 09:59:40) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tiktoken
>>>
Torch
Install PyTorch 2.6 via pip with the conda py313t environment active:
If you have trouble installing PyTorch, consult their Getting Started guide.
You can verify torch installed correctly as follows:
% python -Xgil=0
Python 3.13.1 experimental free-threading build | packaged by conda-forge | (main, Jan 13 2025, 09:59:40) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.cuda.is_available()
True
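A free-threaded 3.13 build can quietly re-enable the GIL at runtime. By default it does so when importing a C extension module that hasn't declared free-threading support, unless you pass `-Xgil=0` or set `PYTHON_GIL=0`. It is therefore worth checking what you’re actually running. The small sketch below uses `sys._is_gil_enabled()`, available from 3.13 onward, and the `Py_GIL_DISABLED` build configuration variable.

```python
import sys
import sysconfig


def gil_status() -> dict[str, bool]:
    """Report whether this is a free-threaded build and whether the GIL is on."""
    # Py_GIL_DISABLED is 1 on free-threaded ("t") builds, 0 or None otherwise.
    free_threaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() was added in 3.13; older versions always have a GIL.
    is_enabled = getattr(sys, "_is_gil_enabled", None)
    gil_enabled = bool(is_enabled()) if is_enabled is not None else True
    return {"free_threaded_build": free_threaded_build, "gil_enabled": gil_enabled}


if __name__ == "__main__":
    print(gil_status())
```

Under py313t started with `-Xgil=0`, this should report a free-threaded build with the GIL disabled. Under py313, `free_threaded_build` is False and `gil_enabled` is True.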
IPython Kernel
Installing the IPython kernel (ipykernel) lets us use our free-threaded Python installation from the Jupyter Lab instance we install in the py313 environment.
This will install a kernel configuration file, kernel.json, which we need to tweak by adding the -Xgil=0 startup flag to the Python interpreter:
% cd ~/.local/share/jupyter/kernels/py313t
% cp kernel.json kernel.json.orig
% vi kernel.json
# Edit kernel.json to make it look like the diff below.
% diff -u kernel.json.orig kernel.json
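If you’d rather script the edit than make it in vi, the minimal sketch below assumes the usual ipykernel layout. In that layout, `argv[0]` is the interpreter path, followed by `"-m", "ipykernel_launcher", "-f", "{connection_file}"`. `add_gil_flag()` is a hypothetical helper, not part of ipykernel.

```python
import json
from pathlib import Path


def add_gil_flag(kernel_json: Path, flag: str = "-Xgil=0") -> bool:
    """Insert `flag` right after the interpreter path in kernel.json's argv.

    Returns True if the file was rewritten, or False if the flag was
    already present.
    """
    spec = json.loads(kernel_json.read_text())
    argv = spec["argv"]
    if flag in argv:
        return False
    # Assumes argv[0] is the Python executable; interpreter options such
    # as -X must come before "-m ipykernel_launcher".
    argv.insert(1, flag)
    kernel_json.write_text(json.dumps(spec, indent=1) + "\n")
    return True
```

Pass it the kernel.json inside the py313t kernel directory reported by `jupyter kernelspec list`.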
(This section is optional for now: the immediate need for datrie was removed by fixing the HTTP server routing logic.)
Datrie and Cython
datrie is a Python library that provides a trie (or digital search tree) data structure by way of the libdatrie C library. The Python datrie library isn’t strictly necessary to run parallelopedia.gpt2, but other components rely on it, so it’s handy to get installed now, if possible.
It relies upon Cython, and thus, for now, you need to install a free-threaded compatible version of Cython first, as follows:
If everything goes well, you should see something like this when you launch Python and import datrie:
% python
Python 3.13.1 experimental free-threading build | packaged by conda-forge | (main, Jan 13 2025, 09:59:40) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import datrie
>>>
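datrie’s own API isn’t covered in this post. To illustrate what the data structure buys you, below is a minimal pure-Python sketch of a character trie. Lookups and prefix queries walk one edge per character, so finding every key under a prefix only visits that branch. This is an illustrative dict-of-dicts layout, not the double-array representation libdatrie uses.

```python
_MISSING = object()  # marks nodes that don't terminate a stored key


class TrieNode:
    __slots__ = ("children", "value")

    def __init__(self) -> None:
        self.children: dict[str, "TrieNode"] = {}
        self.value = _MISSING


class Trie:
    """Minimal character trie: insert, exact lookup, and prefix search."""

    def __init__(self) -> None:
        self.root = TrieNode()

    def __setitem__(self, key: str, value) -> None:
        node = self.root
        for ch in key:
            node = node.children.setdefault(ch, TrieNode())
        node.value = value

    def __getitem__(self, key: str):
        node = self._find(key)
        if node is None or node.value is _MISSING:
            raise KeyError(key)
        return node.value

    def _find(self, key: str):
        """Walk one edge per character; return None if the path doesn't exist."""
        node = self.root
        for ch in key:
            node = node.children.get(ch)
            if node is None:
                return None
        return node

    def keys(self, prefix: str = "") -> list[str]:
        """All stored keys starting with `prefix`, sorted."""
        start = self._find(prefix)
        if start is None:
            return []
        found: list[str] = []
        stack = [(start, prefix)]
        while stack:
            node, path = stack.pop()
            if node.value is not _MISSING:
                found.append(path)
            for ch, child in node.children.items():
                stack.append((child, path + ch))
        return sorted(found)


t = Trie()
for i, word in enumerate(["python", "pytorch", "pyo3", "rust"]):
    t[word] = i
print(t.keys("py"))  # ['pyo3', 'python', 'pytorch']
```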
Normal 3.13 Env (py313)
The second py313 environment is almost identical to py313t, except it is not a python-freethreading installation, and, additionally, we install Jupyter Lab. We can install tiktoken directly via pip, so we don’t need the supporting Rust cruft.
All of the code in this article is available in the Parallelopedia repository on GitHub. The code we’ll be focusing on in this post lives in the parallelopedia.gpt2 module.
There is also a web user interface component named Parallelopedia-UI, which we will use later in the post.
The code and command examples in this post assume you’ve added the src directory to your PYTHONPATH and the bin directory to your PATH, and set the PARALLELOPEDIA_ROOT environment variable to the root of the repository (and PARALLELOPEDIA_UI to the root of the Parallelopedia-UI repository). You can do this as follows:
cd parallelopedia
export PYTHONPATH=$(pwd)/src:$PYTHONPATH
export PATH=$(pwd)/bin:$PATH
export PARALLELOPEDIA_ROOT=$(pwd)
cd ..
cd parallelopedia-ui
export PARALLELOPEDIA_UI=$(pwd)
It is recommended that you add these to your shell’s startup file. As a zsh user, I use the following:
You can perform a quick sanity check that things are working as follows:
% python -Xgil=0 -m parallelopedia.http.server --help
usage: server.py [-h] [--ip IP] [--port PORT] [--debug]
                 [--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}]
                 [--threads THREADS] [--protocol-class PROTOCOL_CLASS]
                 [--app-classes APP_CLASSES [APP_CLASSES ...]]
                 [--listen-backlog LISTEN_BACKLOG]

Run the HTTP server.

options:
  -h, --help            show this help message and exit
  --ip IP               IP address to bind the server to.
  --port PORT           Port number to bind the server to.
  --debug               Enable debug mode for asyncio.
  --log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
                        Set the logging level.
  --threads THREADS     Number of threads to use.
  --protocol-class PROTOCOL_CLASS
                        The protocol class to use for the server.
  --app-classes APP_CLASSES [APP_CLASSES ...]
                        Space-separated list of HTTP application classes.
  --listen-backlog LISTEN_BACKLOG
                        The listen backlog for the server.
PyTorch and LLM Crash Course
My involvement with PyTorch and Large Language Models (LLMs) started around late November last year, 2024. Going in, I knew nothing about PyTorch, nor deep neural networks, nor LLMs—other than having enjoyed using LLMs thoroughly the past couple of years. I had never trained an AI model of any kind. I did have a bit of NumPy and data science exposure up my sleeve, plus general familiarity with Python.
Thanks to Andrej Karpathy’s phenomenal YouTube series on deep neural networks and LLMs titled Neural Networks: From Zero to Hero, over the course of about three weeks I went from zero to… well, I wouldn’t necessarily say hero; perhaps zero to not-completely-clueless is more apropos.
Andrej’s content is a fantastic resource to learn everything you need to know to understand how modern LLMs work from the ground up. It’s not a short series: there are 19 hours, 21 minutes and two seconds of content across ten videos, and you’ll probably spend double that if you really want to properly absorb the content.
None of the work presented in this post would have been possible had I not invested the time in Andrej’s series. If you’re reading this Andrej, thanks, and keep up the brilliant work!
Training GPT-2 (124M) Locally
Equipped with my new knowledge about LLMs and PyTorch, and thanks to Andrej’s final video in the series, titled Let’s reproduce GPT-2 (124M), and the accompanying build-nanogpt GitHub repo, I was able to train a GPT-2 model locally via PyTorch, from scratch, using the edu_fineweb10B dataset.
I only had to make one change in order to run locally: use a micro batch size of 8 instead of 64. With that change in place, I was able to train GPT-2 from scratch as follows:
% conda activate py313
% git clone gh:tpn/build-nanogpt
% cd build-nanogpt
# Download the fineweb dataset.
% python fineweb.py
# Train!
% torchrun --standalone --nproc_per_node=4 train_gpt2.py
This was run on an NVIDIA DGX workstation from 2017, which has an Intel Xeon(R) CPU E5-2698 v4 @ 2.20GHz (20 cores, 40 threads), and four Tesla V100-DGXS-32GB GPUs.
Training in parallel across all four GPUs yielded around 36,000 tokens/sec, with an average time of about 14.5 seconds per loop iteration. Training took about 3 days and 5 hours for 19,072 steps. All four GPUs were pegged close to their 300W power limit for those three days.
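Two pieces of arithmetic tie the micro batch change and these throughput figures together. The quick check below assumes build-nanogpt’s default total batch of 2**19 = 524,288 tokens per optimizer step and a sequence length of 1024. Both values come from my reading of train_gpt2.py, so verify them against the repo.

Gradient accumulation keeps that total fixed, so a micro batch of 8 only changes how many micro-steps each of the four GPUs runs per optimizer step. Treating one 14.5-second loop iteration as one optimizer step reproduces both the tokens/sec figure and the roughly three-day training time.

```python
TOKENS_PER_STEP = 2**19  # 524,288 tokens per optimizer step (assumed build-nanogpt default)
SEQ_LEN = 1024           # T, tokens per sequence (assumed)
STEPS = 19_072           # optimizer steps, from the text above
SECONDS_PER_STEP = 14.5  # average loop-iteration time, from the text above


def grad_accum_steps(total_tokens: int, micro_batch: int, seq_len: int, world_size: int) -> int:
    """Micro-steps per GPU needed so each optimizer step sees `total_tokens` tokens."""
    per_micro_step = micro_batch * seq_len * world_size
    if total_tokens % per_micro_step:
        raise ValueError("total batch must be divisible by B * T * world_size")
    return total_tokens // per_micro_step


tokens_per_sec = TOKENS_PER_STEP / SECONDS_PER_STEP
total_hours = STEPS * SECONDS_PER_STEP / 3600
total_tokens = STEPS * TOKENS_PER_STEP

print(grad_accum_steps(TOKENS_PER_STEP, 64, SEQ_LEN, 4))  # 2
print(grad_accum_steps(TOKENS_PER_STEP, 8, SEQ_LEN, 4))   # 16
print(f"{tokens_per_sec:,.0f} tokens/sec")                # 36,158 tokens/sec
print(f"{total_hours:.1f} hours = {total_hours / 24:.2f} days")  # 76.8 hours = 3.20 days
print(f"{total_tokens:,} tokens seen in total")           # 9,999,220,736 tokens seen in total
```

The last line shows that 19,072 steps of this size amount to roughly one pass over the 10B-token dataset.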
Note
Amusingly, well after the fact, I decided to see what kind of training performance I’d get on my Windows 11 gaming box (via WSL2 and Ubuntu 22.04), which has an AMD Ryzen 9 7950X3D (16 cores, 32 threads) and an NVIDIA RTX 4090. Training via python train_gpt2.py (torchrun wasn’t needed as I wasn’t using multiple GPUs) yielded about 45,000 tokens/sec, which is a nice bump, but what was most impressive was the reduction in loop iteration duration, which averaged out to about 180ms!
So, I could have completed the same training process in about an hour or so, at a vastly reduced impact on my electricity bill that month :-)
Once training completes, a log/model_19072.pt file is produced, which is the checkpoint of the model at that final step, obtained via a call to torch.save(). The model has 124M parameters—which is tiny by modern standards—and is just under 500MB on disk.
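You can check both the 124M parameter count and the just-under-500MB checkpoint size by hand from GPTConfig. The sketch below repeats GPTConfig’s defaults so it runs standalone, then counts parameters module by module. It assumes float32 weights and counts the tied wte/lm_head weight once, as model.parameters() does. torch.save() likewise stores shared tensors only once.

```python
from dataclasses import dataclass


@dataclass
class GPTConfig:
    # Copy of the defaults used by this post's GPTConfig.
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768


def count_params(c: GPTConfig) -> int:
    """Parameter count for the GPT model in this post, with lm_head tied to wte."""
    d = c.n_embd
    layernorm = 2 * d                             # weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)      # c_attn + c_proj, with biases
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)   # c_fc + c_proj, with biases
    block = 2 * layernorm + attn + mlp            # ln_1, attn, ln_2, mlp
    embeddings = c.vocab_size * d + c.block_size * d  # wte + wpe
    # lm_head shares wte's weight and has no bias, so it adds nothing.
    return embeddings + c.n_layer * block + layernorm  # + final ln_f


n = count_params(GPTConfig())
print(f"{n:,} parameters")                 # 124,475,904 parameters
print(f"{n * 4 / 1e6:.1f} MB as float32")  # 497.9 MB as float32
```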
You can download that very model I trained via the HuggingFace dataset I set up here: model_19072.pt. Once downloaded, place the file in $PARALLELOPEDIA_ROOT/data; alternatively, if you run the Jupyter Notebook below, it’ll automatically download the model from HuggingFace on first run.
PyTorch GPT-2 Implementation
Let’s introduce the first version of the Python code we’re going to use. Again, all of this has been made possible thanks to Andrej Karpathy’s work with his YouTube series and build-nanogpt repo, so any and all code you see in this post can typically be traced back to something equivalent that appears in train_gpt2.py. None of this code would have made any sense to me a month or two ago—but I can promise you that if you devote sufficient time to watching and understanding the entire series, you’ll have a comprehensive understanding of all of the pieces present in this post.
You can follow along in a Jupyter Lab notebook if you activate the py313 environment and launch jupyter lab. If you correctly registered your py313t kernel per the instructions earlier, you should see an option when creating a new notebook to use the py313t Python kernel, which will be the free-threaded version.
Initial Implementation
The code below roughly corresponds to my first version of the code in the commit 3ed4fe6: Add gpt2.py, with some formatting and style tweaks to ensure the code is viewable on mobile devices without requiring horizontal scrolling.
We’ll revise this code later in the post, but for now, it’s a good starting point to get a feel for how we can use PyTorch to load a GPT-2 model checkpoint, tokenize some input text, and generate some output text.
# gpt2_v1.ipynb

# ===================================================================
# Imports
# ===================================================================
import dataclasses
from dataclasses import dataclass
import logging
import os
from os.path import join, exists
import subprocess
import sys
import textwrap
from textwrap import wrap
import time

import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F

# ===================================================================
# Helper Timer Class
# ===================================================================
class ElapsedTimer:
    """
    Context manager and reusable timer to measure elapsed time.

    Example:
        timer = ElapsedTimer()
        with timer:
            do_something()
        print(f'Elapsed: {timer.elapsed:.3f}')

        # Re-enterable:
        with timer:
            do_something_else()
        print(f'Elapsed: {timer.elapsed:.3f}')
    """

    def __init__(self):
        self.start = None
        self._elapsed = None

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._elapsed = time.perf_counter() - self.start

    @property
    def elapsed(self):
        """
        Return the elapsed time for the most recent context.
        """
        if self._elapsed is None:
            raise ValueError("Timer has not been used in a context yet.")
        return self._elapsed

# ===================================================================
# Globals
# ===================================================================
LOG_LEVEL = 'DEBUG'
PARALLELOPEDIA_ROOT = os.environ['PARALLELOPEDIA_ROOT']
PARALLELOPEDIA_DATA_DIR = join(PARALLELOPEDIA_ROOT, 'data')

MODEL_CHECKPOINT = join(
    PARALLELOPEDIA_DATA_DIR,
    'model_19072.pt',
)

MODEL_DOWNLOAD_URL = (
    "https://huggingface.co/datasets/trentnelson/"
    "parallelopedia-data-gpt2/resolve/main/model_19072.pt"
)

# Download the model from huggingface if necessary.
os.makedirs(PARALLELOPEDIA_DATA_DIR, exist_ok=True)

if not exists(MODEL_CHECKPOINT):
    print(f'Downloading {MODEL_DOWNLOAD_URL} via wget; '
          'this might take a while...')
    args = [
        "wget",
        "--quiet",
        MODEL_DOWNLOAD_URL,
        "-P",
        PARALLELOPEDIA_DATA_DIR,
    ]
    timer = ElapsedTimer()
    with timer:
        subprocess.run(args, check=True)
    print(f'Downloaded model in {timer.elapsed:.3f} seconds.')

assert exists(MODEL_CHECKPOINT), "Missing checkpoint."

# ===================================================================
# Logging
# ===================================================================
# N.B. We redirect logs to sys.stdout in order for Quarto to pick
#      them up and include them in rendering the output.
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL),
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout,
)

# ===================================================================
# Setup
# ===================================================================

# Allow float32 matmuls to use TF32 (or a bfloat16-based
# approximation) where the hardware supports it.
torch.set_float32_matmul_precision('high')

# ===================================================================
# GPT2 PyTorch Model Components
# ===================================================================

# Now define the classes making up our GPT2 implementation.
# These map directly to the components introduced by the
# now-seminal 2017 "Attention Is All You Need" paper.

class CausalSelfAttention(nn.Module):
    """
    Causal self-attention for the GPT2 model.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Key, query, value projections for all heads, but in a batch.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # Regularization.
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        # Batch size, sequence length, embedding dimensionality.
        B, T, C = (x.size())

        # Calculate query, key, values for all heads in
        # batch and move head forward to be the batch dim.
        #
        # N.B. nh is "number of heads", hs is "head size",
        #      and C (number of channels) is nh * hs.
        #      E.g. in GPT-2 (124M), n_head=12, hs=64, so
        #      nh*hs=C=768 channels in the Transformer.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        head_dim = C // self.n_head

        # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)

        # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)

        # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        # Flash attention.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-assemble all head outputs side by side.
        y = (
            y.transpose(1, 2)
            .contiguous()
            .view(B, T, C)
        )

        # Output projection.
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    """
    Multi-layer perceptron for the GPT2 model.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    """
    Transformer block for the GPT2 model.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

# ===================================================================
# GPT2 Supporting Classes
# ===================================================================

# N.B. These differ slightly from Andrej's classes in
#      `train_gpt2.py`.  `GPTCheckpoint` is a helper
#      class I wrote that has no analog in the former.

@dataclass
class GPTConfig:
    """
    Configuration class for GPT model.

    Attributes:

        block_size (int): Maximum sequence length.

        vocab_size (int): Number of tokens.  GPT2 from
            huggingface has a vocab size of 50257, which includes
            50,000 BPE merges, 256 byte tokens, and 1
            <|endoftext|> token.  However, Andrej Karpathy's
            `build-nanogpt/train_gpt2.py` uses a vocab size of
            50304.  I vaguely recall the explanation for this
            discrepancy as a local optimization to yield better
            alignment sizes (50304 is 393 * 128, whereas 50257 is
            odd), but I'm not 100% certain.

            The local GPT2 training that we did on
            edu_fineweb10b used 50304, so we will use that here.

        n_layer (int): Number of layers.

        n_head (int): Number of attention heads.

        n_embd (int): Embedding dimension.
    """
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

# ===================================================================
# GPT2 Model Implementation
# ===================================================================

class GPT(nn.Module):

    def __init__(self, config, device):
        super().__init__()
        self.config = config
        self.device = device
        self.manual_seed = 42

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                h=nn.ModuleList(
                    [Block(config) for _ in range(config.n_layer)]
                ),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(
            config.n_embd, config.vocab_size, bias=False
        )

        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        Forward pass of the GPT model.

        Args:

            idx (torch.Tensor): Supplies the input tensor of shape
                (B, T).

            targets (torch.Tensor): Optionally supplies the target
                tensor of shape (B, T) for computing the loss.

        """
        (B, T) = idx.size()

        # Forward the token and position embeddings.

        # Shape (T)
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)

        # Position embeddings of shape (T, n_embd).
        pos_emb = self.transformer.wpe(pos)

        # Token embeddings of shape (B, T, n_embd).
        tok_emb = self.transformer.wte(idx)

        x = tok_emb + pos_emb

        # Forward the blocks of the transformer.
        for block in self.transformer.h:
            x = block(x)

        # Forward the final layernorm and the classifier.
        x = self.transformer.ln_f(x)

        # (B, T, vocab_size)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )

        return (logits, loss)

    @classmethod
    def from_local_pretrained(
        cls, model_path: str, map_location: str = "cuda"
    ):
        """
        Load a model from a local checkpoint.

        N.B. This is a new method based off GPT.from_pretrained in
             Andrej Karpathy's train_gpt2.py.

        Args:

            cls (type): Supplies the class type.

            model_path (str): Supplies the path to the model
                checkpoint.

            map_location (str): Supplies the device to which
                the model will be mapped.
        """
        with torch.serialization.safe_globals([GPTConfig]):
            checkpoint = torch.load(
                model_path,
                map_location=map_location,
            )

        config = checkpoint["config"]
        config = GPTConfig(**checkpoint["config"])
        model = cls(config, device=map_location)
        model.load_state_dict(checkpoint["model"])
        model.eval()

        msg = (
            f"Loaded model from step {checkpoint['step']}, "
            f"val_loss {checkpoint['val_loss']}"
        )
        logging.info(msg)
        return model

    def generate(
        self,
        text: str,
        max_length: int = 1024,
        top_k: int = 50,
        seed: int = None,
    ) -> str:
        """
        Generate text from the model.

        N.B. This is a new method based off the generation code
             present in Andrej Karpathy's train_gpt2.py.

        Args:

            text (str): Supplies the prompt.

            max_length (int): Supplies the maximum total length,
                including prompt.

            top_k (int): Supplies the number of tokens to consider
                at each generation step.

            seed (int): Optionally supplies the manual seed to use
                for the generator.  If None, the model's manual
                seed will be used.

        Returns:

            str: The generated text (including the initial prompt).
        """
        self.eval()
        device = self.device

        # Obtain our GPT2 tokenizer, and resolve the stop token.
        enc = tiktoken.get_encoding("gpt2")
        stop_string = '<|endoftext|>'
        stop_token = enc.n_vocab - 1
        actual = enc.decode([stop_token])
        assert actual == stop_string, (
            f"expected {stop_string}, got {actual}"
        )

        # Encode the prompt.
        tokens = enc.encode(text)
        x = torch.tensor(
            tokens, dtype=torch.long, device=device
        ).unsqueeze(0)

        # Create a random generator for reproducibility.
        if seed is None:
            seed = self.manual_seed
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        # Generate tokens up to our max length, or until we hit the
        # stop token.
        start = time.perf_counter()
        count = 0
        while x.size(1) < max_length:
            count += 1
            with torch.no_grad():
                # Forward pass, ignoring the returned loss.
                (logits, _) = self(x)

                # Take the logits at the last time-step (shape:
                # (1, vocab_size)).
                logits = logits[:, -1, :]

                # Convert to probabilities.
                probs = F.softmax(logits, dim=-1)

                # Top-k sampling.
                topk_probs, topk_indices = torch.topk(
                    probs, k=top_k, dim=-1
                )

                # Sample the next token.
                next_idx = torch.multinomial(
                    topk_probs, num_samples=1, generator=sample_rng
                )
                next_token = torch.gather(topk_indices, -1, next_idx)

                # If the next token is the stop token, we're done.
                if next_token.item() == stop_token:
                    break

                # Otherwise, append the token to the current sequence
                # and continue generation.
                x = torch.cat((x, next_token), dim=1)

        end = time.perf_counter()
        elapsed = end - start
        tokens_per_sec = float(count) / elapsed

        msg = (
            f'Generated {count} tokens in {elapsed:.2f} seconds '
            f'({tokens_per_sec:.2f} tokens/sec)'
        )
        logging.debug(msg)

        # Decode the output tokens and return the generated text,
        # including the initial prompt.
        output_tokens = x[0].tolist()
        return enc.decode(output_tokens)
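The sampling step at the bottom of generate() is what makes the seed argument matter. Below is a dependency-free sketch of the same top-k procedure on a hypothetical five-token vocabulary. It uses random.Random in place of torch.Generator as an illustrative substitution; the post’s code uses torch.topk and torch.multinomial.

```python
import math
import random


def top_k_sample(logits: list[float], k: int, rng: random.Random) -> int:
    """Sample one token id following the same steps as GPT.generate():
    softmax over all logits, keep the k most probable tokens, then
    sample among them in proportion to their probabilities."""
    # Softmax; subtracting the max logit avoids overflow in exp().
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Token ids of the k largest probabilities (the torch.topk step).
    top = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    # random.choices() accepts unnormalized weights, much as torch.multinomial does.
    return rng.choices(top, weights=[probs[i] for i in top], k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0, 3.0]  # toy 5-token vocabulary
rng = random.Random(42)               # seeded, like torch.Generator.manual_seed()
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(10)]
print(samples)  # only ids 4 and 0 (the two largest logits) can appear
```

Re-seeding with the same value reproduces the same samples. generate() relies on this property when it seeds its torch.Generator.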
Loading the Model
With the code above executed in a preceding Jupyter Notebook cell, we can load the model as follows:
model = GPT.from_local_pretrained(
    MODEL_CHECKPOINT,
    map_location='cuda',
)
model.to('cuda')
2025-02-09 15:26:39,136 - INFO - Loaded model from step 19072, val_loss 3.0519702434539795
Once we’ve got a model instance, text can be generated by simply calling the model’s generate() method with a prompt and, optionally, additional parameters like seed and max_length. In this context, this is also referred to as inference: the act of providing some input tokens (your encoded prompt) to your trained model, and having it generate tokens in response.
Note that this isn’t a chat or instruction model, nor has it been fine-tuned to remotely resemble something actually usable. So you can’t ask it questions or have it write code for you.
What you can do, though, is provide it with a half-written sentence, and then laugh at the ridiculous content it generates in response. That said, its syntax is pretty good: the model has clearly learned enough during training about how the English language is structured, which words make sense when placed together, that sort of thing. It just has no clue about underlying semantics.
prompt = "Albert Einstein's Theory of Relativity stated that"
result = model.generate(prompt, seed=42)
print('\n' + textwrap.fill(result, width=105))
2025-02-09 15:26:40,464 - DEBUG - Generated 79 tokens in 0.81 seconds (98.10 tokens/sec)
Albert Einstein's Theory of Relativity stated that the speed of light was approximately 10 000 of
parsecs, whereas quantum physicists have suggested that, as we move further into the universe, the
universe might grow older. The new experiment, conducted by researchers at the University of New Jersey,
New York, and the University of California, Berkeley shows that photons travelling at the speed of light
will be around 30 to 65 kilometres per second.
Now, it’s worth noting at this point that a 124 million parameter GPT2 model, trained from scratch for 19,072 iterations on the edu_fineweb10b data set, with a final loss score of 3.052, is, quite frankly, hot garbage :-)
Don’t expect output from this model to come anywhere close to what you’d expect from a contemporary LLM. In fact, we can’t even rely on it to generate content that is remotely factual. It spews hot probabilistic garbage that mostly conforms to the structure of the English language.
But at least it’s our hot garbage that we trained from nothing, and it’s all we need to start playing around with generating text in parallel.
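To put the final loss of 3.052 in perspective, cross-entropy loss in nats converts to perplexity via exp(loss). Perplexity is roughly the number of equally likely tokens the model is effectively choosing between at each step. The quick calculation below assumes the logged val_loss is a mean per-token cross-entropy as computed by F.cross_entropy, which uses the natural log.

```python
import math

val_loss = 3.0519702434539795  # from the "Loaded model" log line above
vocab_size = 50304             # GPTConfig.vocab_size

# Cross-entropy (nats) -> perplexity.
perplexity = math.exp(val_loss)
# Baseline: the loss of a model that spreads probability uniformly over the vocab.
uniform_loss = math.log(vocab_size)

print(f"perplexity ~ {perplexity:.1f}")            # perplexity ~ 21.2
print(f"uniform-guess loss ~ {uniform_loss:.2f}")  # uniform-guess loss ~ 10.83
```

A uniform guess over the vocabulary would score a loss of about 10.83. The model therefore learned a great deal relative to that baseline, yet it still hedges between roughly twenty candidate tokens at every step.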
Hot Garbage Example
Provide a different seed with the same prompt and you’ll see hot garbage generation in action. In the following example, with a prompt identical to the one before, it just streams nonsense until hitting its max token limit, nary a stop token in sight.
result = model.generate(prompt, seed=20190903)
print('\n' + textwrap.fill(result, width=100))
2025-02-09 15:26:57,518 - DEBUG - Generated 1015 tokens in 15.95 seconds (63.65 tokens/sec)
Albert Einstein's Theory of Relativity stated that the speed of light is the same as it is in two
places, which means that a given speed can either be described by two different speed equations
directly or they may be both equations. It is then assumed that the speed of light is the speed of
the universe or the universe's existence relative to Earth. In relativity, a measure of the speed of
light is the absolute speed of the light. As long as the speed of light is less than its speed in
two different places, the absolute speed can be calculated. For example, the absolute speed is
1/2990000000 (2,299,792,458) km/hr with an absolute speed about 10 times as fast as it is in two
different places. Now we can use the following equation to describe the speed of light: E = C/C2 The
speed of light, as a function of C, is a constant. By Einstein's definition of relativity, the speed
of light is a constant. This is because light travels at its maximum speed along the direction (if
it's travelling above the speed of light, the point where light must be observed is called
"aperture" of the speed of light). The speed of light is about half as fast as the speed of light
because the speed of light has a smaller varying velocity for each direction of radiation. The speed
of light, as a function of C, is a constant. The speed of a wave is the constant measured along the
direction of the wave relative to its location in space. E = C/C2 where E is the speed of light
along the direction of the wave. Because the speed of the wave is the speed of the particle in the
wave, and c the speed of the particle, E's is also given by the speed of light. For example, a light
particle is moving from its place of greatest velocity to its location of greatest velocity. E.g. C
= F/d, C = d/d For most materials and most other objects, the speed of light is the same for all
wavelengths. The speed of light is, on the other hand, the speed of the energy form of a photon.
E.g. c = C/d, C = e/d For most particles, light travels over one degree of separation and this is
how photons interact with other particles. We can compare a particle's velocity to an object's
velocity. The speed of light is measured by the distance between the particle's nose and the surface
of the object. For example, a photon of light emits the energy of a single photon. If a photon of
another type is fired at the same speed as the first, it will get out of the light, but a photon of
the other type will not get back to the ground. The fractional energy will be reduced. The distance
between two photons of the same type will be reduced to the square of their energies. E.g. C = C/C2,
C = -D/d., D = 9/6 A photon of color does not have sufficient energy to be emitted by that color and
is therefore subject to The speed of light is the change in velocity over time. This is a constant,
but sometimes it is possible to express it like this: E = c2/e In relativity, the length of the
distance is the length of time the length of wave is divided by the speed of light. E.g. a beam of
light travelling at about 9.2 miles per second must travel at around 7.3 miles per second to get
E.g. a beam moving at 3.2 miles per second must travel at around 8 miles per second to get E.g. a
beam moving at 1.8 miles per second must travel at 9.0 miles per second to get E.g. an object going
at 2.3 miles per second must travel at 1.8 miles per second to get E.g. a beam moving at 2.3 miles
per second must travel at 3.4 miles per second to get E.g.. a beam traveling at 3.4 miles per second
to get E.g.. a beam moving at 2.3 miles per second must travel at 3.8 miles per second to get E.g..
a beam traveling at 3.8 miles per second to get E.g.. a beam moving at about 4.4 miles per second
must travel at about 3.9 miles per second to get E.g.. a beam moving at 5.5 miles per second to get
a beam moving at 5.9 miles per S.G.D.. is the same thing as a mass. The distance is a unit in terms
of the speed of light. Determining the speed of light is an additional measure of the energy. For
most things
lolwat.
Parallel PyTorch Inference
Now that we’ve got generation working, let’s tackle the fun part: investigating whether or not PyTorch model inference can be done simultaneously, in parallel, by multiple threads running on multiple cores at the same time, in a single free-threaded Python process. Ideally, we should only need one GPU, with all of these threads sharing it as fairly as possible, although if we have multiple GPUs, we should be able to distribute the incoming work evenly across those too.
This is novel, uncharted territory that we’re now able to explore thanks to the free-threaded version of Python.
Simultaneous vs Parallel vs Concurrent
The phrasing I used above, “simultaneously, in parallel”, is a bit redundant: both terms imply the same thing. When I use either word in this post, I’m explicitly referring to the new ability unlocked by free-threaded Python, where multiple threads can run Python code on different CPU cores at the same time (i.e. simultaneously), and thus perform work in parallel.
When I use the term concurrent or concurrency, I’m using it in the traditional sense within the context of Python: making progress on multiple things at a time. This term is well suited to describe things like web servers, where a single Python process, with a single thread, running on a single CPU core, can service multiple clients at any given time (by way of non-blocking socket I/O and event multiplexing), but it’s not servicing any of those clients simultaneously on different cores, because that would require multiple threads running in parallel.
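To make the distinction concrete, here’s a minimal sketch of concurrency in the traditional asyncio sense; the client names and step counts are made up for illustration. Two simulated clients make interleaved progress on a single event loop, and every step runs on the same thread.

```python
import asyncio
import threading


async def handle_client(name: str, steps: int, log: list) -> None:
    # Simulate servicing one client: do a sliver of work, then hand
    # control back to the event loop so other clients can progress.
    for step in range(steps):
        log.append((name, step, threading.get_ident()))
        await asyncio.sleep(0)


async def serve_concurrently() -> list:
    log: list = []
    # Both "clients" are scheduled on the same event loop, which runs
    # on a single thread: concurrent, but never simultaneous.
    await asyncio.gather(
        handle_client("client-1", 3, log),
        handle_client("client-2", 3, log),
    )
    return log


if __name__ == "__main__":
    for entry in asyncio.run(serve_concurrently()):
        print(entry)
```

The log alternates between the two clients, but every entry carries the same thread identifier: progress on multiple things at a time, with no parallelism involved.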
So how do we test this out? I guess we could spin up some threads and have them all call model.generate(), but that’s a little boring.
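For completeness, the boring approach looks roughly like the sketch below. fake_generate() is a hypothetical, CPU-bound, pure-Python stand-in for model.generate(); on a free-threaded build the worker threads can run it on separate cores at the same time, whereas with the GIL they would simply take turns. sys._is_gil_enabled() only exists on Python 3.13 and later, hence the getattr() fallback.

```python
import sys
from concurrent.futures import ThreadPoolExecutor


def fake_generate(prompt: str, n: int = 200_000) -> str:
    # Hypothetical stand-in for model.generate(): a CPU-bound loop
    # written in pure Python, so it needs the interpreter (and, on a
    # regular build, the GIL) for the entire time it runs.
    total = 0
    for i in range(n):
        total += i * i
    return f"{prompt} -> {total}"


def generate_in_parallel(prompts, max_workers: int = 4) -> list:
    # Each prompt is handed to a worker thread; map() returns the
    # results in the same order as the prompts.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fake_generate, prompts))


if __name__ == "__main__":
    # sys._is_gil_enabled() was added in Python 3.13.
    gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)()
    print(f"GIL enabled: {gil_enabled}")
    for line in generate_in_parallel(["a", "b", "c", "d"]):
        print(line)
```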
Why not try to implement a pure Python, multi-threaded, asyncio-based HTTP server, expose a /generate-esque GET endpoint, and wire that up to the model generation code we saw above? That would let us serve generation requests in parallel, ideally in an asyncio-friendly manner where we stream individual tokens back one by one via HTTP’s chunked transfer encoding, giving each thread’s event loop the ability to service multiple clients concurrently in a reasonably fair manner, whilst also servicing many clients in parallel across all threads. And for kicks, we could get AI to whip up a janky little React Bootstrap UI to slap in front of it all for testing.
Pure Python Multi-threaded HTTP Server
The first thing we need is a nice, simple asyncio-based HTTP server that also happens to work with multiple threads, now that we have free-threaded Python at our disposal.
Luckily, I already have one of those lying around! In support of another article I’m actively working on (which was meant to be published before this post), I ported the HTTP server I wrote many years ago for PyParallel[2] to the asyncio facilities available in modern Python, and then slapped multiple threads on it, where each thread gets its own asyncio event loop.
Thankfully, it turns out this Just Works, at least on Linux[3]: we can now have a pure Python HTTP server, running in a single Python process, that will happily saturate every CPU core under load.
The server code lives in parallelopedia.http.server, and it includes a super janky little notion of “HTTP Apps”, the purpose of which can be best demonstrated with a simple example:
```python
import time

# N.B. HttpApp, HttpServer, route and text_response are assumed to be
#      imported from parallelopedia's HTTP server code.


class PlaintextApp(HttpApp):
    @route
    def plaintext(self, request):
        response = text_response(request, 'Hello, World!')
        self.server.send_response(response)


class SleeperApp(HttpApp):
    @route
    def sleep(self, request, seconds):
        time.sleep(int(seconds))
        msg = f'Slept for {seconds} seconds.'
        response = text_response(request, msg)
        self.server.send_response(response)


# Create a server with the two apps.
app_classes = [PlaintextApp, SleeperApp]
server = HttpServer(app_classes=app_classes)
```
In the above example, the HTTP server routes requests for the /plaintext endpoint to the plaintext() method of a PlaintextApp instance, and requests for /sleep to the sleep() method of a SleeperApp instance.
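One plausible way to get that kind of routing is sketched below. This is a hypothetical, heavily simplified illustration of the idea, not how parallelopedia actually implements it: make_router() and dispatch() are made-up names, a decorator records each routed method under its function name, and the first path component selects the method while any remaining components become positional arguments (so /sleep/3 calls sleep(request, '3')).

```python
from typing import Callable, Dict, Tuple


def make_router() -> Tuple[Dict[str, Callable], Callable]:
    # Hypothetical registry: route name (the function's name) -> function.
    routes: Dict[str, Callable] = {}

    def route(func: Callable) -> Callable:
        routes[func.__name__] = func
        return func

    return routes, route


routes, route = make_router()


class SleeperApp:
    @route
    def sleep(self, request, seconds):
        # No actual sleeping in this sketch; just echo the argument.
        return f"Slept for {seconds} seconds."


def dispatch(app, path: str):
    # "/sleep/3" -> name "sleep", args ["3"]
    name, *args = path.strip("/").split("/")
    func = routes.get(name)
    if func is None:
        raise LookupError(f"no route for /{name}")
    # A real server would pass a parsed request object instead of None.
    return func(app, None, *args)


if __name__ == "__main__":
    print(dispatch(SleeperApp(), "/sleep/3"))
```

Because the registry is keyed on bare function names, two apps defining a method with the same name would collide; a real implementation needs to account for that.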
The “slapped multiple threads on it” activity I referred to earlier is handled by some new async and threading scaffolding added to the bottom of that module; the pertinent pieces are reproduced below:
```python
import argparse
import asyncio
import logging
import os
import threading
from typing import Optional, Tuple

# N.B. parse_arguments(), get_classes_from_strings_parallel() and
#      get_class_from_string() live elsewhere in the module.


async def main_async(
    args: argparse.Namespace,
    protocol_class: type,
    *protocol_args: Tuple
) -> None:
    """
    This is the main function for the server when it is running in
    asynchronous mode.  It will create a server instance and then
    call serve_forever() on it.

    Arguments:

        args (argparse.Namespace): Supplies the command-line arguments.

        protocol_class (type): Supplies the protocol class to use.

        protocol_args (tuple): Supplies the arguments to pass to the
            protocol class constructor.
    """
    loop = asyncio.get_running_loop()

    if os.name in ('nt', 'cygwin'):
        reuse_port = False
    else:
        reuse_port = True

    reuse_address = True

    server = await loop.create_server(
        lambda: protocol_class(*protocol_args),
        args.ip,
        args.port,
        backlog=args.listen_backlog,
        reuse_address=reuse_address,
        reuse_port=reuse_port,
    )

    async with server:
        await server.serve_forever()


def start_event_loop(
    args: argparse.Namespace,
    protocol_class: type,
    *protocol_args: Tuple
) -> None:
    """
    This function will start the asyncio event loop and run the
    main_async() function.  It is intended to be the target of a
    threading.Thread.

    Arguments:

        args (argparse.Namespace): Supplies the command-line arguments.

        protocol_class (type): Supplies the protocol class to use.

        protocol_args (tuple): Supplies the arguments to pass to the
            protocol class constructor.
    """
    asyncio.run(
        main_async(
            args,
            protocol_class,
            *protocol_args,
        ),
        debug=args.debug,
    )


def main_threaded_multi_accept(
    args: argparse.Namespace,
    protocol_class: type,
    *protocol_args: Tuple
) -> None:
    """
    This is the main function for the server when it is running in
    multi-threaded mode with multiple accept sockets.  Each thread
    will have its own asyncio loop issue a create_server() call for
    the given host/port and protocol.

    Arguments:

        args (argparse.Namespace): Supplies the command-line arguments.

        protocol_class (type): Supplies the protocol class to use.

        protocol_args (tuple): Supplies the arguments to pass to the
            protocol class constructor.
    """
    threads = []
    for _ in range(args.threads):
        thread = threading.Thread(
            target=start_event_loop,
            args=(args, protocol_class, *protocol_args),
        )
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()


def main(args: Optional[argparse.Namespace] = None):
    """
    Main entry point for parallelopedia.http.server module.
    """
    # Only parse the command line if the caller didn't supply arguments.
    if args is None:
        args = parse_arguments()

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    # Use multiple threads to load the application classes.
    app_classes = get_classes_from_strings_parallel(
        args.app_classes,
    )

    protocol_class = get_class_from_string(args.protocol_class)
    protocol_args = (app_classes,)

    if args.threads == 1:
        asyncio.run(
            main_async(
                args,
                protocol_class,
                *protocol_args,
            ),
            debug=args.debug,
        )
    else:
        main_threaded_multi_accept(
            args,
            protocol_class,
            *protocol_args,
        )


if __name__ == '__main__':
    main()
```
GPT2 HTTP App
With the HTTP server scaffolding in place, we can now whip up a little Gpt2App class that has a generate() method. Incoming requests to the /generate endpoint will be routed to that routine by the server.
Synchronous Up-Front Generation
Now, we could take the simple approach, where the Gpt2App.generate() call goes off and calls model.generate() and then patiently waits for the entire response to be generated before sending anything back to the user.
That code would look something like this:
```python
from typing import Dict, List


class Gpt2App(HttpApp):

    @route
    def generate(self, request: Request, *args: List, **kwds: Dict) -> None:
        prompt = args[0]
        model = get_model()
        result = model.generate(prompt=prompt)
        response = text_response(request, result)
        self.server.send_response(response)
```
But when have you ever interacted with an LLM via a web interface where it waits until it generates all of the response up-front before sending it back to you? Never; you can see it generate the response in real time, and that’s what we want to mimic here in this experiment.
Our Goals
The high-level goals for our solution are thus:
We want to send a generated token back to the user as soon as it becomes available.
We want to ensure the client receiving the token can display it as soon as it arrives, so we need to be cognizant of which HTTP transfer encoding we use to send bytes back. If we just used a regular, non-chunked response, the client would wait until we’d sent everything before showing the user anything, despite the fact that we’d been trickling individual tokens to it the entire time.
We want to play nicely with the asyncio ecosystem upon which our hosting HTTP server is based—so we need to be cognizant of the current thread’s event loop, and make sure we don’t impede that thread’s ability to concurrently serve other clients that are being handled by the event loop.
Thankfully, as we saw earlier with the implementation of the GPT.generate() routine, generating tokens in response to a prompt is inherently a token-by-token process. So the algorithm at least provides us with the means to obtain a single token at a time, which takes care of the first point.
Second, HTTP’s chunked transfer encoding allows an HTTP client to “see” the tokens we send back as soon as we send them, provided we also enable TCP_NODELAY on the underlying socket so that the operating system sends the bytes out to the client immediately.
Note
If we didn’t do this, Nagle’s algorithm would apply by default: the operating system would hold back small writes while previously sent data is still unacknowledged, in the hope of coalescing them into a larger segment it can send all at once a little later. This is advantageous for maximizing throughput, but it hurts latency, and in our case we want the lower latency afforded by sending bytes to the client as soon as we generate them.
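At the socket level, toggling Nagle’s algorithm is a single setsockopt() call. The helper names below are hypothetical (the server wraps this in its own response.enable_tcp_nodelay() and response.disable_tcp_nodelay() routines, whose internals aren’t shown), but the socket calls are the standard library’s.

```python
import socket


def set_tcp_nodelay(sock: socket.socket, enabled: bool) -> None:
    # TCP_NODELAY=1 disables Nagle's algorithm: small writes go out
    # immediately instead of being held back and coalesced.
    value = 1 if enabled else 0
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)


def tcp_nodelay_enabled(sock: socket.socket) -> bool:
    # getsockopt() returns the option as an int; non-zero means "on".
    return sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) != 0


if __name__ == "__main__":
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        set_tcp_nodelay(s, True)
        print("TCP_NODELAY enabled:", tcp_nodelay_enabled(s))
```

Inside asyncio, the same calls work on the object returned by transport.get_extra_info("socket"). As far as I know, CPython’s selector-based asyncio transports already enable TCP_NODELAY on TCP sockets when they’re created, so reading the value back with getsockopt() before toggling it is a cheap sanity check on your Python version.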
Chunked encoding works by setting the HTTP response header Transfer-Encoding: chunked; the body is then sent as a series of chunks, each prefixed by its length in hexadecimal followed by CRLF, with the chunk data itself followed by another CRLF. The server signals to the client that the transfer is complete by sending a zero-length chunk.
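The wire format is simple enough to sketch in a few lines. encode_chunk(), end_of_chunks() and decode_chunked() are hypothetical helpers, not the response.send_chunk()/end_chunks() API the server uses, and the decoder ignores chunk extensions and trailers for brevity. Note that the length prefix counts bytes, not characters, which matters for non-ASCII tokens.

```python
def encode_chunk(data) -> bytes:
    # Length in hex, CRLF, the data itself, CRLF.
    raw = data.encode("utf-8") if isinstance(data, str) else bytes(data)
    if not raw:
        # A zero-length chunk would prematurely end the stream.
        raise ValueError("empty chunk; use end_of_chunks() to terminate")
    return f"{len(raw):x}\r\n".encode("ascii") + raw + b"\r\n"


def end_of_chunks() -> bytes:
    # The zero-length "last chunk", followed by the final CRLF.
    return b"0\r\n\r\n"


def decode_chunked(body: bytes) -> bytes:
    # Reassemble a complete chunked body (no extensions or trailers).
    out = bytearray()
    pos = 0
    while True:
        eol = body.index(b"\r\n", pos)
        size = int(body[pos:eol].split(b";")[0], 16)
        pos = eol + 2
        if size == 0:
            return bytes(out)
        out += body[pos:pos + size]
        pos += size + 2  # Skip the data and its trailing CRLF.


if __name__ == "__main__":
    tokens = ["Albert", " Einstein", "'s"]
    body = b"".join(encode_chunk(t) for t in tokens) + end_of_chunks()
    print(body)
    print(decode_chunked(body))
```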
So, as long as we send our tokens back via chunked encoding, any HTTP/1.1 client will be able to reassemble them into the generated text, giving the visual appearance of real-time model generation. That takes care of the second point.
Lastly, in order to play nice within the greater asyncio ecosystem, we need to give control back to the underlying thread’s asyncio event loop after we generate a token and yield a decoded text fragment, which can thankfully be done via a simple call to await asyncio.sleep(0), provided we’re generating text from the model from within an async callback.
This ensures multiple concurrent clients being handled by our thread’s event loop will be handled fairly; they’ll all receive generated tokens at the same rate.
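The effect of that await asyncio.sleep(0) is easy to see with a toy example. token_stream() is a hypothetical stand-in for the model’s token generator: with yield_control=True, two clients sharing one event loop receive their tokens in lockstep, and without it the first client hogs the loop until its stream is exhausted.

```python
import asyncio


async def token_stream(name: str, n: int, yield_control: bool):
    # Hypothetical token generator: yields "<name><i>" n times.
    for i in range(n):
        yield f"{name}{i}"
        if yield_control:
            # Hand control back to the event loop between tokens.
            await asyncio.sleep(0)


async def client(name: str, n: int, yield_control: bool, log: list) -> None:
    async for token in token_stream(name, n, yield_control):
        log.append(token)


async def run(yield_control: bool) -> list:
    log: list = []
    await asyncio.gather(
        client("a", 3, yield_control, log),
        client("b", 3, yield_control, log),
    )
    return log


if __name__ == "__main__":
    print("with sleep(0):   ", asyncio.run(run(True)))
    print("without sleep(0):", asyncio.run(run(False)))
```

An async generator that never awaits anything never suspends its task, so without the sleep(0) the `async for` loop runs to completion before the other client gets a turn.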
Asynchronous Token-by-Token Generation
The first thing we need to do is change our Gpt2App.generate() call into something that is async-compatible, since code we write later needs to issue an await asyncio.sleep(0), which can only be done from within a coroutine (i.e. an async def function).
When our Gpt2App.generate() routine is called, we’re still within the context of the asyncio protocol’s data_received() routine, which is a normal, synchronous method, not an enlightened async method that can participate in an asyncio event loop.
So, in order to transition from a synchronous callback to an asynchronous one, we can use the current event loop’s create_task() routine to schedule a coroutine for execution.
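Here’s a minimal sketch of that hand-off, using a made-up ProtocolSketch class in place of a real asyncio.Protocol. One detail worth borrowing from the asyncio documentation: the event loop only keeps weak references to tasks, so the docs recommend holding a strong reference to the task returned by create_task() until it finishes, which the sketch does with a set.

```python
import asyncio


class ProtocolSketch:
    # Hypothetical stand-in for an asyncio.Protocol: data_received()
    # is a plain synchronous method, so it can't `await` anything.

    def __init__(self) -> None:
        self.sent: list = []
        self.background_tasks: set = set()

    def data_received(self, data: bytes) -> None:
        loop = asyncio.get_running_loop()
        task = loop.create_task(self.generate_response(data.decode()))
        # Keep a strong reference until the task completes.
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

    async def generate_response(self, prompt: str) -> None:
        # Inside a coroutine, so yielding to the loop is now possible.
        for word in prompt.split():
            self.sent.append(word)
            await asyncio.sleep(0)


async def demo() -> list:
    proto = ProtocolSketch()
    proto.data_received(b"hello free threaded world")
    # Nothing has run yet: the coroutine has only been scheduled.
    assert proto.sent == []
    await asyncio.gather(*list(proto.background_tasks))
    return proto.sent


if __name__ == "__main__":
    print(asyncio.run(demo()))
```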
Step 1: Have generate() enqueue an async generate_response().
Thus, our Gpt2App.generate() call will look something like this:
```python
class Gpt2App(HttpApp):
    ...

    @route
    def generate(self, request: Request, *args: List, **kwds: Dict) -> None:
        # Extract the "prompt" provided in the incoming request.
        text = args[0]

        # Obtain the event loop and schedule the response
        # generation via our async generation coroutine.
        # We have to do it like this as at this point we're
        # still within the call frame of the data_received()
        # protocol callback, which isn't an async function.
        loop = asyncio.get_running_loop()
        async_task = self.generate_response(request, text, **kwds)
        loop.create_task(async_task)
```
Step 2: Implement an async generate_response()
Our asynchronous generate_response() routine will be the bridge between generating tokens from the model, and sending those tokens back to the client via chunked-encoding.
It is responsible for preparing the response to use chunked-encoding, and then enabling TCP_NODELAY on the socket.
Then, assuming our model has a generate_async_for() routine (which we’ll implement in the next step), we iterate over it with an async for loop to obtain individual decoded tokens. As soon as we receive a token, we send it back to the client via the response object’s send_chunk() routine.
Once we’ve exhausted the async generator (i.e. it either generated the maximum number of requested tokens, or it encountered a stop token), we terminate the chunked response via end_chunks(), disable TCP_NODELAY again, and return.
A simplified version of the Python code is presented below. I have omitted most of the error handling and query parameter parsing code for simplicity; the full version follows it.
```python
class Gpt2App(HttpApp):
    ...

    async def generate_response(
        self, request: Request, prompt: str, **kwds: Dict
    ) -> None:
        # Prepare a chunked-encoding response.
        response = request.response
        response.code = 200
        response.message = 'OK'
        response.chunked_response = True
        response.content_type = 'text/plain'

        # Obtain the model.
        model = get_model()

        # We want to enable TCP_NODELAY for the duration of
        # the response.  This ensures packets are sent
        # immediately without any internal buffering.
        try:
            response.enable_tcp_nodelay()
            enabled_nodelay = True
        except Exception as e:
            logging.error(f'Error enabling TCP_NODELAY: {e}')
            enabled_nodelay = False

        # Write the chunked header immediately.
        response_bytes = bytes(response)
        if not self.write(response_bytes):
            # Encountered a disconnect, return.
            return

        # N.B. From herein, all data must be transferred to
        #      the client via chunked encoding with the
        #      `response.send_chunk()` routine.

        # Send the initial prompt text.
        response.send_chunk(prompt)

        # Obtain decoded tokens from the model one at a time
        # via an `async for` loop, sending the token back to
        # the client as soon as it's available.
        async for decoded_token in model.generate_async_for(prompt):
            response.send_chunk(decoded_token)

        # Terminate the chunked-encoding response.
        response.end_chunks()

        # Disable TCP_NODELAY now that the response is complete.
        # The reasoning behind this is that the client may have
        # issued the HTTP request with a keep-alive header, and
        # plans on reusing this connection for a different request
        # next, which won't necessarily want `TCP_NODELAY` active.
        if enabled_nodelay:
            response.disable_tcp_nodelay()
```
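To see the three ingredients (chunked encoding, TCP_NODELAY, and yielding to the loop between tokens) working together outside of parallelopedia, here’s a self-contained sketch built on asyncio’s streams API rather than the server’s protocol and response classes. fake_generate_async_for() is a hypothetical stand-in for the model, request parsing is skipped entirely, and a small client reassembles the chunks.

```python
import asyncio
import socket


async def fake_generate_async_for(prompt: str):
    # Hypothetical stand-in for model.generate_async_for().
    for token in [" is", " hot", " garbage", "."]:
        yield token
        await asyncio.sleep(0)


def chunk(data: str) -> bytes:
    raw = data.encode("utf-8")
    return f"{len(raw):x}\r\n".encode("ascii") + raw + b"\r\n"


async def handle(reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter) -> None:
    await reader.readline()                       # Request line.
    while (await reader.readline()) not in (b"\r\n", b""):
        pass                                      # Skip request headers.
    prompt = "This"                               # Prompt parsing omitted.
    sock = writer.get_extra_info("socket")
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    writer.write(
        b"HTTP/1.1 200 OK\r\n"
        b"Content-Type: text/plain; charset=utf-8\r\n"
        b"Transfer-Encoding: chunked\r\n\r\n"
    )
    writer.write(chunk(prompt))
    async for token in fake_generate_async_for(prompt):
        writer.write(chunk(token))
        await writer.drain()
    writer.write(b"0\r\n\r\n")                    # Zero-length last chunk.
    await writer.drain()
    writer.close()
    await writer.wait_closed()


async def fetch(host: str, port: int) -> str:
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(b"GET /generate/This HTTP/1.1\r\nHost: localhost\r\n\r\n")
    await writer.drain()
    headers = await reader.readuntil(b"\r\n\r\n")
    assert b"Transfer-Encoding: chunked" in headers
    body = bytearray()
    while True:
        size = int((await reader.readline()).strip(), 16)
        if size == 0:
            await reader.readline()               # Final CRLF.
            break
        body += await reader.readexactly(size)
        await reader.readexactly(2)               # CRLF after chunk data.
    writer.close()
    await writer.wait_closed()
    return body.decode("utf-8")


async def demo_stream() -> str:
    server = await asyncio.start_server(handle, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    async with server:
        return await fetch("127.0.0.1", port)


if __name__ == "__main__":
    print(asyncio.run(demo_stream()))
```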
Full Code for async Gpt2App.generate_response()
The actual code has more robust error-handling facilities and support for extracting the query string parameters from the incoming request URI and converting them into keyword arguments suitable for passing to the model.
Additionally, we haven’t touched on how we initialize or obtain instances of our models yet, so the model-related code won’t make much sense until later in the article.
```python
import logging
import random
from typing import Dict

# N.B. HttpApp, HttpServer, Request, make_routes, router, PRINTABLE,
#      MODELS, PRETRAINED_MODELS, NUM_GPUS, get_next_model,
#      get_next_pretrained_model and OOPS_NON_PRINTABLE_ENCOUNTERED
#      are defined elsewhere.


class Gpt2App(HttpApp):
    routes = make_routes()
    route = router(routes)

    def __init__(self, server: HttpServer) -> None:
        super().__init__(server)
        self.printable = PRINTABLE

    def is_connected(self):
        # server.transport will be severed when the client
        # disconnects, so we can use this to determine if
        # the client is still connected.
        server = self.server
        transport = None
        try:
            transport = server.transport
        except AttributeError:
            pass
        return transport is not None

    def write(self, response_bytes):
        server = self.server
        transport = None
        try:
            transport = server.transport
        except AttributeError:
            pass
        if transport is not None:
            transport.write(response_bytes)
            return True
        else:
            return False

    async def generate_response(
        self, request: Request, prompt: str, **kwds: Dict
    ) -> None:
        response = request.response
        response.code = 200
        response.message = 'OK'
        response.chunked_response = True
        response.content_type = 'text/plain'

        if kwds is None:
            kwds = {}
        max_length = min(int(kwds.get('max_length', 100) or 100), 1024)
        top_k = min(int(kwds.get('top_k', 50) or 50), 50)
        seed = kwds.get('seed', None)
        if seed:
            try:
                seed = int(seed)
            except ValueError:
                seed = None
        if not seed:
            seed = random.randint(0, 2**32 - 1)

        device = kwds.get('device', None)

        model_name = kwds.get('model', None)
        if model_name == 'gpt2-xl':
            models = PRETRAINED_MODELS
            get_next = get_next_pretrained_model
        else:
            model_name = 'gpt2'
            models = MODELS
            get_next = get_next_model

        model = None
        if device is not None:
            if device == 'cpu':
                model = models[-1]
            elif device.startswith('cuda:'):
                try:
                    index = int(device[5:])
                except ValueError:
                    index = -1
                if index < 0 or index >= NUM_GPUS:
                    index = -1
                if index != -1:
                    model = models[index]
            elif device == 'cuda':
                model = models[random.randint(0, NUM_GPUS - 1)]

        if not model:
            # Get a model.  If there are multiple models available, e.g. if we
            # have multiple GPUs, this will balance the load a bit.
            model = get_next()

        expose_headers = (
            'Access-Control-Expose-Headers: '
            'X-Max-Length, '
            'X-Top-K, '
            'X-Seed, '
            'X-Model-Name, '
            'X-Model-Device'
        )
        response.other_headers.extend([
            expose_headers,
            f'X-Max-Length: {max_length}',
            f'X-Top-K: {top_k}',
            f'X-Seed: {seed}',
            f'X-Model-Name: {model_name}',
            f'X-Model-Device: {model.device}',
        ])

        # We want to enable TCP_NODELAY for the duration of
        # the response.  This ensures packets are sent
        # immediately without any internal buffering.
        try:
            response.enable_tcp_nodelay()
            enabled_nodelay = True
        except Exception as e:
            logging.error(f'Error enabling TCP_NODELAY: {e}')
            enabled_nodelay = False

        # Write the chunked header immediately.
        response_bytes = bytes(response)
        if not self.write(response_bytes):
            # Encountered a disconnect, return.
            return

        # N.B. From herein, all data must be transferred to
        #      the client via chunked encoding with the
        #      `response.send_chunk()` routine.

        # Send the initial prompt text.
        response.send_chunk(prompt)

        # Obtain an async generator instance to the model's
        # new async token generation routine.
        generate_tokens = model.generate_async_for(
            prompt,
            max_length=max_length,
            top_k=top_k,
            seed=seed,
        )
        async for decoded_token in generate_tokens:
            if decoded_token == -1:
                # A non-printable token was generated,
                # terminating generation.
                response.send_chunk(
                    OOPS_NON_PRINTABLE_ENCOUNTERED
                )
                break

            # If the client has forcibly disconnected,
            # terminate generation.
            if not self.is_connected():
                break

            # Otherwise, send the decoded token to the client
            # via chunked encoding.
            response.send_chunk(decoded_token)

        # Send the termination chunk.  This may fail at the
        # socket.send() level if the client has already
        # disconnected, which is harmless.
        response.end_chunks()

        # Disable TCP_NODELAY now that the response is complete.
        # Again, this may fail at the socket level if the client
        # has already disconnected, which is harmless.
        if enabled_nodelay:
            try:
                response.disable_tcp_nodelay()
            except Exception as e:
                msg = f'Error disabling TCP_NODELAY: {e}'
                logging.error(msg)
```
Step 3: Implement an async GPT.generate_async_for()
In the code example above, we assumed the GPT model we’ve been using had grown a new async routine named generate_async_for(), which we’ll cover now.
This routine is essentially an asyncio-friendly version of the original generate() routine we wrote. It shares a lot of the same code, with a few notable tweaks: it runs inside a coroutine scheduled on a thread’s asyncio event loop, it yields each token as soon as it is available, and it then passes control back to the event loop so that other clients can be serviced before it continues generating the next token.
It also has the notion of checking for “printable” characters. This came about when I was initially testing this code via curl, which would sometimes balk and exit in the middle of streaming the response, complaining that it had encountered corrupted data or something along those lines.
After investigation, it turned out that sometimes, for whatever reason, the model just generates a junk, nonsensical token (like 65534, which is well outside the highest token number of 50384). I have no idea why it happens, although I’ll note it happens on the OpenAI GPT2 XL model available on HuggingFace (which we’ll discuss later) too, so, eh.
I deal with this by checking whether the decoded token is non-printable and, if so, yielding -1 and terminating the loop. The full version of the Gpt2App.generate_response() routine introduced above checks for -1 and, if it sees it, terminates generation with an oopsie message, e.g.:
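A printable check along those lines could look like the following sketch. is_printable() and printable_tokens() are hypothetical; the real code compares against a PRINTABLE constant whose contents aren’t shown here, so treat the character rules below (str.isprintable() plus newline, carriage return and tab) as an assumption.

```python
import asyncio

# Assumption: common whitespace is allowed even though
# str.isprintable() reports it as non-printable.
_ALLOWED_WHITESPACE = {"\n", "\r", "\t"}


def is_printable(fragment: str) -> bool:
    return all(ch.isprintable() or ch in _ALLOWED_WHITESPACE
               for ch in fragment)


async def printable_tokens(fragments):
    # Yield decoded fragments until a non-printable one shows up, then
    # yield the -1 sentinel that generate_response() checks for and stop.
    for fragment in fragments:
        if not is_printable(fragment):
            yield -1
            return
        yield fragment
        await asyncio.sleep(0)


async def collect(fragments) -> list:
    return [token async for token in printable_tokens(fragments)]


if __name__ == "__main__":
    print(asyncio.run(collect(["Hello", ",", "\x07", "never sent"])))
```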
```python
OOPS_NON_PRINTABLE_ENCOUNTERED = (
    'Oops! Non-printable token encountered.  Generation terminated.'
)

...

async for decoded_token in generate_tokens:
    if decoded_token == -1:
        # A non-printable token was generated,
        # terminating generation.
        response.send_chunk(
            OOPS_NON_PRINTABLE_ENCOUNTERED
        )
        break
```
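For illustration, here is a minimal, standalone sketch of that printable check. It assumes the GPT class’s `self.printable` is a set built from Python’s `string.printable`; the post doesn’t show how that attribute is constructed, so `PRINTABLE`, `is_printable_fragment()`, and `filter_fragments()` are hypothetical names.

```python
import string

# Hypothetical stand-in for the GPT class's `self.printable` attribute;
# assumed here to be the ASCII characters in `string.printable`.
PRINTABLE = frozenset(string.printable)


def is_printable_fragment(fragment: str, printable=PRINTABLE) -> bool:
    """Return True if every character in a decoded token is printable."""
    return all(c in printable for c in fragment)


def filter_fragments(fragments):
    """Yield decoded fragments until a non-printable one is seen, then -1."""
    for fragment in fragments:
        if not is_printable_fragment(fragment):
            # Mirror generate_async_for(): signal with -1 and stop.
            yield -1
            return
        yield fragment


if __name__ == "__main__":
    # '\ufffd' is the Unicode replacement character, a typical junk fragment.
    print(list(filter_fragments([" quick", " brown", "\ufffd", " fox"])))
```

Because `string.printable` only covers ASCII, a check built this way would also stop generation on legitimate non-ASCII text such as accented letters.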
After yielding a valid decoded token, we issue an await asyncio.sleep(0) call, which returns control back to the event loop for it to potentially handle other concurrent clients. If there are no other clients, or after it has handled all other enqueued work, generation resumes.
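To see why that `await asyncio.sleep(0)` matters, here’s a small, self-contained sketch (not the server’s code) in which two consumers share a single event loop, i.e. a single thread. `fake_token_stream()` is a hypothetical stand-in for `generate_async_for()`.

```python
import asyncio


async def fake_token_stream(name: str, n: int):
    """Hypothetical stand-in for generate_async_for(): yields n tokens."""
    for i in range(n):
        yield f"{name}{i}"
        # Hand control back to the event loop so other tasks can run.
        await asyncio.sleep(0)


async def client(name: str, n: int, log: list) -> None:
    """Consume one stream, recording the order tokens were produced."""
    async for token in fake_token_stream(name, n):
        log.append(token)


async def main() -> list:
    log = []
    # Two "clients" multiplexed on the same event loop.
    await asyncio.gather(client("a", 3, log), client("b", 3, log))
    return log


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Running it prints `['a0', 'b0', 'a1', 'b1', 'a2', 'b2']`: the two streams take turns. Delete the `await asyncio.sleep(0)` line and the first stream runs to completion before the second gets a turn (`['a0', 'a1', 'a2', 'b0', 'b1', 'b2']`), which is exactly the hogging behavior the comment in `generate_async_for()` warns about.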
The full code follows. It’s simple enough as-is that I didn’t feel the need to omit any details, unlike the prior example.
```python
class GPT:

    ...

    async def generate_async_for(
        self,
        text: str,
        max_length: int = 1024,
        top_k: int = 50,
        seed: int = None,
    ):
        """
        Asynchronously generate text from the model, yielding tokens
        one at a time as soon as they are available.

        Args:

            text (str): Supplies the prompt.

            max_length (int): Supplies the maximum total length,
                including prompt.

            top_k (int): Supplies the number of tokens to consider
                at each generation step.

            seed (int): Optionally supplies the manual seed to use
                for the generator.  If None, the model's manual
                seed will be used.

        Yields:

            str: The newly generated decoded token.  If -1, a
                non-printable token was generated, and generation
                was terminated.
        """

        enc = self.enc
        stop_token = self.stop_token

        # Encode the prompt -> tensor of shape (1, T)
        tokens = enc.encode(text)
        x = torch.tensor(
            tokens, dtype=torch.long, device=self.device
        ).unsqueeze(0)

        sample_rng = torch.Generator(device=self.device)
        if seed is None:
            seed = self.manual_seed
        sample_rng.manual_seed(seed)

        logging.debug(
            f'[generate_async_for] Starting generation loop for {text} '
            f'with seed {seed}.'
        )

        start_time = time.perf_counter()
        count = 0
        while x.size(1) < max_length:
            count += 1
            with torch.no_grad():
                # Forward pass, ignoring the returned loss.
                (logits, _) = self(x)

            # Take the logits at the last time-step (shape:
            # (1, vocab_size)).
            logits = logits[:, -1, :]

            # Convert to probabilities.
            probs = F.softmax(logits, dim=-1)

            # Top-k sampling.
            topk_probs, topk_indices = torch.topk(
                probs,
                k=top_k,
                dim=-1,
            )

            # Sample the next token.
            next_idx = torch.multinomial(
                topk_probs,
                num_samples=1,
                generator=sample_rng,
            )
            next_token = torch.gather(topk_indices, -1, next_idx)

            # If the next token is the stop token, we're done.
            next_token_item = next_token.item()
            if next_token_item == stop_token:
                break

            # Append token to current sequence.  Although we only
            # yield a singular decoded token below, we still need
            # to keep track of the entire sequence for subsequent
            # generation steps.
            x = torch.cat((x, next_token), dim=1)

            # Decode the newly-generated token.  Note that a single
            # token will often be decoded to multiple characters.
            new_text_fragment = enc.decode([next_token_item])

            # If any of the characters in the decoded text
            # representation aren't printable, terminate
            # generation.
            if not all(c in self.printable for c in new_text_fragment):
                yield -1
                break

            yield new_text_fragment

            # Yield control back to the event loop before continuing
            # generation.  If we didn't do this, this client would
            # hog the thread's event loop, preventing other clients
            # associated with the loop from getting serviced.  (As
            # we're now running multiple threads in parallel, other
            # clients associated with event loops on other threads
            # would not be impacted.)
            await asyncio.sleep(0)

        elapsed = time.perf_counter() - start_time
        logging.debug(
            f"[generate_async_for] Generated {count} tokens in "
            f"{elapsed:.2f} seconds (~{count / elapsed:.2f} tok/s)"
        )
```
This routine was the last piece we needed to implement to satisfy our three goals captured earlier, so we’re now ready to test it out!
Test Drive!
Launching the HTTP Server
We can launch an instance of our multi-threaded HTTP web server with our new Gpt2App HTTP application via the command line as follows:
This will start up a multi-threaded HTTP server listening on all interfaces on port 4444, with 40 threads, and two HTTP applications: our Gpt2App, which will service requests to the /generate endpoint, and a PlaintextApp that just returns “Hello, World!” to any incoming request received for the /plaintext endpoint.
Visualizing Chunked-Encoding
Let’s visualize the generation response in a way that shows us the raw chunked-encoding, without doing any client-side reassembly. We can achieve that with echo and netcat (nc):
```
HTTP/1.1 200 OK
Server: Parallelopedia Web Server v1.0
Date: Fri, 07 Feb 2025 23:32:02 GMT
Accept-Ranges: bytes
Content-Type: text/plain
Access-Control-Allow-Origin: *
Connection: close
Transfer-Encoding: chunked
Access-Control-Expose-Headers: X-Max-Length, X-Top-K, X-Seed, X-Model-Name, X-Model-Device
X-Max-Length: 20
X-Top-K: 50
X-Seed: 42
X-Model-Name: gpt2
X-Model-Device: cuda:0

13
The quick brown fox
3
 is
2
 a
4
 sub
7
species
5
 that
B
 originated
3
 in
9
 southern
9
 Scotland
3
 as
2
 a
8
 variety
3
 of
4
 fox
1
.
5
 This
0
```
As you can see, we’ve enabled chunked encoding by way of the Transfer-Encoding: chunked header, and the body of the response is composed of these “chunks”: each piece of decoded text is preceded by its length in bytes, written in hexadecimal (hence 13 for the 19-byte “The quick brown fox” and B for the 11-byte “ originated”), followed by \r\n, then the text itself, then another \r\n.
The zero-length chunk at the end indicates completion of the transfer, and as we requested Connection: close in our headers, our HTTP server closes the connection once the generation has completed.
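The framing is simple enough to reproduce in a few lines. The following is a minimal sketch of HTTP/1.1 chunked encoding and decoding, not the Parallelopedia server’s `send_chunk()`/`end_chunks()` implementation; `encode_chunk()`, `last_chunk()`, and `decode_chunked()` are hypothetical names, and the decoder ignores trailers and only strips (never interprets) chunk extensions.

```python
def encode_chunk(data: bytes) -> bytes:
    """Frame one chunk: hex length, CRLF, payload, CRLF."""
    return f"{len(data):X}\r\n".encode("ascii") + data + b"\r\n"


def last_chunk() -> bytes:
    """The terminating zero-length chunk, with no trailers."""
    return b"0\r\n\r\n"


def decode_chunked(body: bytes) -> bytes:
    """Reassemble a complete chunked body into the original payload."""
    out = bytearray()
    pos = 0
    while True:
        line_end = body.index(b"\r\n", pos)
        # Drop any ";extension" after the size, then parse the hex size.
        size = int(body[pos:line_end].split(b";")[0], 16)
        pos = line_end + 2
        if size == 0:
            return bytes(out)
        out += body[pos:pos + size]
        pos += size + 2  # Skip the payload and its trailing CRLF.


if __name__ == "__main__":
    fragments = [b"The quick brown fox", b" is", b" a", b" originated"]
    body = b"".join(encode_chunk(f) for f in fragments) + last_chunk()
    print(body)
    print(decode_chunked(body))
```

The hex digits may be upper- or lowercase; the server above happens to emit uppercase (`B`).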
Note
In HTTP/1.1, “keep-alive” is the default behavior—i.e., a server won’t close a connection unless the client specifically requests it. This is the opposite of HTTP/1.0 behavior, where the server will close a connection by default unless a client furnishes a Connection: keep-alive header.
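Expressed as code, the default-persistence rule from that note looks roughly like the sketch below. It’s a simplified, hypothetical helper (`connection_persists()`), not anything from the server’s code; it ignores proxies and any HTTP versions other than 1.0 and 1.1.

```python
from typing import Optional


def connection_persists(http_version: str,
                        connection_header: Optional[str]) -> bool:
    """
    Decide whether a connection stays open after a response, using
    the HTTP/1.0 vs HTTP/1.1 defaults (simplified sketch).
    """
    # The Connection header is a comma-separated, case-insensitive list.
    tokens = {
        t.strip().lower()
        for t in (connection_header or "").split(",")
        if t.strip()
    }
    if "close" in tokens:
        return False
    if http_version == "HTTP/1.1":
        # Persistent by default in HTTP/1.1.
        return True
    # HTTP/1.0: only persistent if the client explicitly asked for it.
    return "keep-alive" in tokens
```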
Verifying via Curl
If we switch over to curl and run the same generation request, we’ll see the reassembled text, and, provided we supply the --no-buffer argument, curl will also display decoded text as soon as it receives it.
```
% curl --no-buffer --verbose \
    'http://dgx:4444/generate/The%20quick%20brown%20fox?'\
'max_length=20&seed=42&device=cuda'
* Trying 10.0.132.48:4444...
* Connected to dgx (10.0.132.48) port 4444 (#0)
> GET /generate/The%20quick%20brown%20fox?max_length=20&seed=42&device=cuda HTTP/1.1
> Host: dgx:4444
> User-Agent: curl/7.81.0
> Accept: */*
>
* Mark bundle as not supporting multiuse
< HTTP/1.1 200 OK
< Server: Parallelopedia Web Server v1.0
< Date: Fri, 07 Feb 2025 23:05:34 GMT
< Accept-Ranges: bytes
< Content-Type: text/plain
< Access-Control-Allow-Origin: *
< Transfer-Encoding: chunked
< Access-Control-Expose-Headers: X-Max-Length, X-Top-K, X-Seed, X-Model-Name, X-Model-Device
< X-Max-Length: 20
< X-Top-K: 50
< X-Seed: 42
< X-Model-Name: gpt2
< X-Model-Device: cuda:2
<
The quick brown fox is a subspecies that originated in southern Scotland as a variety of fox. This
* Connection #0 to host dgx left intact
```
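If you’d rather consume the stream from Python than from curl, the standard library’s `http.client` undoes the chunked framing for you. A minimal sketch, assuming the server from above is reachable at `dgx:4444`; `generate_path()` and `stream_generate()` are hypothetical helpers, and `HTTPResponse.read1()` is used so that text is handed back as soon as it arrives rather than after the whole body has been read.

```python
import codecs
import http.client
from urllib.parse import quote, urlencode


def generate_path(prompt: str, **params) -> str:
    """Build a /generate path: percent-encoded prompt plus query options."""
    return f"/generate/{quote(prompt, safe='')}?{urlencode(params)}"


def stream_generate(host: str, port: int, prompt: str, **params):
    """Yield decoded text as it arrives, much like `curl --no-buffer`."""
    conn = http.client.HTTPConnection(host, port)
    try:
        conn.request("GET", generate_path(prompt, **params))
        resp = conn.getresponse()
        if resp.status != 200:
            raise RuntimeError(f"HTTP {resp.status} {resp.reason}")
        # An incremental decoder copes with a multi-byte UTF-8
        # character being split across two reads.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        # read1() returns whatever is currently available, stripping
        # the chunked framing, and b"" once the body is complete.
        while chunk := resp.read1(4096):
            text = decoder.decode(chunk)
            if text:
                yield text
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail
    finally:
        conn.close()
```

For example, `for text in stream_generate('dgx', 4444, 'The quick brown fox', max_length=20, seed=42, device='cuda'): print(text, end='', flush=True)` should behave much like the curl invocation above.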
Launching the React Bootstrap UI
The React Bootstrap UI can be launched as follows:
```
% cd $PARALLELOPEDIA_UI    # i.e. root of gh:tpn/parallelopedia-ui
% conda activate py313t    # or whatever has recent nodejs/npm
% npm start run
```
Note
Full disclaimer: I’m not a web developer. I don’t know JavaScript, React, or Bootstrap, or anything else in the modern web stack. So like any good developer in 2025 when confronted with a task they have absolutely no business doing… I palmed it off to AI—either ChatGPT, local LLMs via LM Studio, or, more recently, Aider.
TL;DR: the web interface code probably sucks.
Running npm start run should open up a browser window pointing at http://localhost:3000. Ignore the Wiki tab—that’s for another article—and switch to the GPT2 tab. You should see something similar to the GIF below, which is a demo of me using the interface to generate some text:
Not too shabby! Granted, it’s a tad jarring seeing characters per second instead of tokens per second. If we wanted to display a live tokens/sec rate, some possible options might be:
Alter GPT.generate_async_for() to calculate how long it took to generate each token (whilst we’re in the main generation loop), convert that into a tokens/sec rate, and then change our API on both the server side and the JavaScript UI side such that each chunk that gets sent back is encoded with both the rate and the actual chunk. A drawback of this approach is that hitting the /generate endpoint with curl or netcat would look wacky, as you’d see token/sec floats in between each generated set of decoded characters.
Send the raw integer tokens back as they’re generated, instead of first decoding them. That would allow client-side JavaScript to calculate tokens/sec, but it would make it much harder to inspect the output with command line tools like curl or echo ... | nc. It would also require adding a JavaScript decoding library on the client side (although that’s not such a big deal; I think you can just do npm install tiktoken these days).
Have the JavaScript client re-encode the decoded characters received back into their GPT2-tokenizer token representation in order to determine actual underlying token count, and recalculate the rate based off that. This rate would differ from the tokens/sec rate observed on the server because it would also include network latency.
Do something more advanced with a separate WebSocket or something, where the live tokens/sec generation rate can be displayed independently to the generated decoded tokens.
Just YOLO it and have the JavaScript code divide the characters per second by, say, three, and pretend that’s roughly the tokens/sec rate (assuming that on average, one token represents three characters in any given response).
I don’t want to do any of that! So characters per second it is, for now, however jarring it may be.
Parallel Generation
Alright, we’ve exercised the /generate endpoint a few different ways via a single client connection, and it looks like it’s behaving the way we hoped it would, albeit with a whopping simultaneous client count of, precisely, one.
Let’s look at introducing some simultaneous load by way of multiple machines all issuing /generate requests at the same time, first via the echo ... | nc approach, and then via curl.
I’ll leverage tmux and the :setw synchronize-panes command, which sends my console keystrokes to all active panes within a given tmux session window.
Netcat Example
I captured a 15-frames-per-second GIF of this in action, which you can view below. It shows terminal sessions to six machines, arranged vertically, all executing the same echo ... | nc command depicted above.
Note
This might seem like a bit of a silly test, especially when it’s working correctly, because the output is pretty benign, and we don’t really know that these requests were actually being served simultaneously by different threads on the server side.
When I first tried it, though, it absolutely did not work correctly: one session would get correct output, another would get the HTTP headers in triplicate, whilst two other sessions just went straight into the chunked responses, with no preceding HTTP headers in sight. A bug in my HTTP routing code (which also exists in the PyParallel code I wrote ten years ago, upon which I based the new HTTP server!) was entirely to blame. With that code ripped out and @route reimplemented in a much simpler fashion, everything worked well.
Parallel Netcat Generation
I extracted the last frame of the GIF, below, where you can see that at least there was some variance between which GPUs were selected:
Parallel Netcat Generation - Last Frame
Curl Example
I did the same thing with the panes arranged horizontally, this time using curl with max_length=1000 and no seed, which helps in visualizing the parallel generation, as you can clearly see different clients receiving completely different outputs.
Parallel Curl Generation
Let’s Load Test This Sucker!
So far so good! Looks like we can absolutely do PyTorch model generation in parallel via multiple threads thanks to free-threaded Python.
Let’s ramp it up a notch and see what happens if we attempt to load test the /generate endpoint from a different server. The leopard server you see in the next example is an Intel(R) Xeon(R) W-2275 CPU @ 3.30GHz (14 cores, 28 threads) and is connected to the dgx box via a 10G Ethernet link.
Using the fork of wrk I hacked together some ten years ago or so whilst load testing PyParallel, let’s kick off a run with 14 threads for 30 seconds. There will be one connection per thread, and HTTP/1.1 will be used, so the implicit keep-alive behavior means that as soon as one /generate request completes, another one is immediately dispatched.
For posterity, the console command being used is as follows:
```
% time ./wrk -v --latency -c14 -t14 -d30s \
    'http://dgx:4444/generate/''The%20quick%20brown%20fox?max_length=20&device=cuda&seed=42'
```
In the GIF below you’ll see two terminal windows. The smaller foreground window on the left is leopard, the session doing the wrk load test. The background window that takes up the rest of the screen is logged in to dgx, our server, and it’s running a GPU-enabled build of btop, which is a beautiful console app for displaying resource usage. I particularly like btop in this example as it does a good job of conveying the CPU and GPU load that kicks in when the load test starts.
The 40 CPU cores can be seen in the top pane, and, as expected, about 14 of them whirr into life as soon as the load test begins, which tracks, as there are now 14 clients attempting to hammer our /generate endpoint.
Below that, you can see the four Tesla V100-DGXS-32GB GPUs also whirr into action, handling the generation between them with relative ease.
On the bottom right, I’ve filtered the process tree to just display our python HTTP server invocation. (I would love it if btop showed the individual threads and their load, as it would beautifully depict Python saturating cores in parallel now that there’s no GIL impeding execution, however, I don’t believe that’s currently possible with this version of btop.)
Parallel Load Testing (No GIL)
Parallel Generation Load Test: No GIL
I’ve extracted the last frame, below, to allow closer inspection at the end of the run, without the annoyance of the GIF looping.
Parallel Generation Load Test No GIL - Last Frame
The console output from wrk is captured in the callout below.
A visualization of the latencies is presented below. All code for visualization is from the Data Visualization Jupyter Notebook, which you can preview here, or access directly on Github.
Parallel Load Test Latency Distribution (No GIL)
Ablation Test: Re-enable the GIL
Let’s see what happens if we re-enable the GIL. We should see poor resource utilization on the server side, as only one Python thread can run at a time, and the statistics reported on the client side should be equally dismal.
In the GIF below, I use another terminal window to launch the server with the -Xgil=1 argument, which re-enables the GIL, then switch back over to leopard and run the same wrk load test as before.
Parallel Generation Load Test: GIL Enabled
As before, I’ve extracted the last frame, below, to allow closer inspection at the end of the run, without the annoyance of the GIF looping.
Parallel Generation Load Test - GIL Enabled - Last Frame
As we expected: re-enabling the GIL absolutely murders resource utilization.
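When juggling free-threaded and GIL-enabled runs like this, it can be handy to have a process report which mode it’s actually in. A minimal sketch using `sysconfig.get_config_var('Py_GIL_DISABLED')` (set on free-threaded builds) and `sys._is_gil_enabled()` (a private helper added in Python 3.13); `gil_status()` is a hypothetical name.

```python
import sys
import sysconfig


def gil_status() -> dict:
    """Report whether this is a free-threaded build and whether the GIL
    is actually enabled right now."""
    # Py_GIL_DISABLED is 1 on free-threaded (e.g. python3.13t) builds.
    free_threaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() only exists on 3.13+; older interpreters
    # always run with the GIL, so fall back to True.
    is_gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)
    return {
        "free_threaded_build": free_threaded_build,
        "gil_enabled": bool(is_gil_enabled()),
    }


if __name__ == "__main__":
    print(gil_status())
```

On a free-threaded 3.13t interpreter, running this normally should report `gil_enabled: False`, while running it with `-X gil=1` (or the `PYTHON_GIL=1` environment variable) should report `True`. Importing an extension module that doesn’t declare free-threading support can also re-enable the GIL at runtime, which is one more reason to check rather than assume.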
The console output from wrk follows in the next callout.
And a visualization of the latencies follows. Note that there were 60 socket timeouts in this case, whereas no timeouts were encountered in the prior run with the GIL disabled.
Parallel Load Test Latency Distribution (GIL Enabled)
Parallel Load Testing: No GIL vs GIL
Viewing them side by side:
Parallel Load Test Latency Distribution (No GIL & GIL Enabled Combined)
Keep in mind that the No GIL vs GIL throughput was 70.80 vs 7.35 requests/sec, nearly a ten-fold increase:
Parallel Load Test Requests/sec (No GIL vs GIL)
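As a back-of-the-envelope check on those figures, the speedup ratio and, via Little’s law, the mean per-request latency each closed-loop wrk connection implies can be computed directly. This assumes each of the 14 connections always has exactly one request in flight and ignores the GIL run’s socket timeouts, so treat the latency figures as rough estimates rather than measurements.

```python
CONNECTIONS = 14      # wrk -c14: one in-flight request per connection.
NO_GIL_RPS = 70.80    # Requests/sec with the GIL disabled.
GIL_RPS = 7.35        # Requests/sec with the GIL enabled.


def speedup(fast_rps: float, slow_rps: float) -> float:
    """Throughput ratio between two runs."""
    return fast_rps / slow_rps


def implied_mean_latency(connections: int, rps: float) -> float:
    """Little's law, L = lambda * W, rearranged to W = L / lambda."""
    return connections / rps


if __name__ == "__main__":
    print(f"speedup: {speedup(NO_GIL_RPS, GIL_RPS):.2f}x")
    print(f"no GIL: ~{implied_mean_latency(CONNECTIONS, NO_GIL_RPS):.3f} s/request")
    print(f"GIL:    ~{implied_mean_latency(CONNECTIONS, GIL_RPS):.3f} s/request")
```

This works out to a speedup of about 9.63x, with implied mean latencies of roughly 0.198 s per request without the GIL versus roughly 1.905 s with it.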
Parallel Load Testing: How does Plaintext Fare?
As we launched our HTTP server invocation earlier with the PlaintextApp class, which simply responds to the /plaintext endpoint with b'Hello World\n', let’s throw some load at that too.
This doesn’t have anything to do with PyTorch; it’s an orthogonal load test that’s fun because it depicts the stark difference between GIL vs no GIL.
The console command was issued on leopard, like last time, as follows:
```
% time ./wrk -v --latency -c14 -t14 -d30s 'http://dgx:4444/plaintext'
```
The combined requests/sec visualization depicts a similar 10x improvement:
Plaintext Parallel Load Test Requests/sec (No GIL vs GIL)
The latencies are a bit misleading when viewed in isolation, as you can see the tail end of the No GIL run incur higher latencies compared to the GIL run—however, as can be seen above, the No GIL run was doing 10x more requests/sec.
Plaintext Parallel Load Test Latency Distribution (No GIL vs GIL)
Parallel Load Test Summary
I think it’s safe to say we’ve achieved our original goals in a satisfactory manner. We can absolutely now do parallel model inference in a single Python process, thanks to free-threading, and the whole thing works great when wrapped around an asyncio-based HTTP server that yields tokens one-by-one as soon as they’re generated, which is a closer representation of what you’d want in the real world if you were to deploy such a thing.
My takeaway from all of this: free-threading kicks ass. Having devoted so much time toward a parallel Python solution with PyParallel over a decade ago, it makes me incredibly happy to see a working, performant solution finally getting mainlined into the core Python language.
It’s important to consider how many other things are simultaneously unlocked by Python free-threading. The HTTP server we’ve demonstrated above also has directory/file-serving behavior built-in, so it also functions like a normal web server would. In a separate article, I’ll introduce the Wiki server component, which features a WikiApp HTTP app that loads a 56GB XML file, plus about 12GB of supporting index structures (by way of datrie digital search tries, and NumPy arrays). This app happily loads in the same single Python process as our existing Gpt2App, demonstrating the power of accessing huge data structures in parallel from multiple threads, something that couldn’t be done with multiprocessing without paying the cost of replicating that memory in every process.
Finishing Up
I’ll tackle two more topics before we conclude this post. First, I want to briefly touch on a handful of the changes I made to the GPT-2 implementation after the version we introduced in the Initial Implementation section.
Second, let’s see what happens if we throw the new graph compilation optimizations in PyTorch 2.0 at our solution; specifically, can we still use model = torch.compile(model) in our free-threading solution?
Reviewing Initial Implementation
Skipping Weight Initialization
The first version of the GPT class we introduced here looked like this:
As I mentioned earlier, I ripped that almost verbatim from Andrej’s train_gpt2.py code. For about two weeks during development of this work, I was using a local free-threaded build of PyTorch, but had forgotten that I had built it in debug configuration.
Everything was dog slow, but having had no real prior experience with PyTorch, I had no frame of reference for how long things should have been taking. Still, I was getting annoyed enough at how long it seemed to take to load the model (which I was doing frequently during development) that I added a bunch of timing code around things to try and bisect where the overhead was being introduced.
Looking at the __init__() code above, though, I was thoroughly perplexed by the _init_weights(self, module) function. Why would we be initializing weights at the end of the GPT constructor if we were just about to override them with the weights we were loading from the checkpoint?
Eliminating the call to _init_weights() entirely shaved 15 seconds off the time it took to simply create an “empty” GPT() instance, i.e. before we’d even loaded the weights. However, it was still taking 15 seconds just to execute this block of code:
I hoisted that code out into a separate _init_transformer() routine and surrounded it with some optional torch.profiler.profile() glue:
```python
timer = ElapsedTimer()
with timer:
    if not self.torch_profile_activities:
        self._init_transformer()
    else:
        with torch.profiler.profile(
            activities=self.torch_profile_activities,
            with_stack=True,
        ) as prof:
            self._init_transformer()
        self.torch_profile_init_transformer = prof

msg = f'Initialized GPT model in {timer.elapsed:.3f} seconds.'
logging.info(msg)
```
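The `ElapsedTimer` used above isn’t shown in this post. A context manager along the following lines would support the same `with timer:` / `timer.elapsed` usage, including reuse of one instance across several timed blocks as `from_pretrained()` does later; treat it as a hypothetical sketch rather than the project’s actual class.

```python
import time


class ElapsedTimer:
    """
    Hypothetical sketch of a reusable timing context manager; the
    project's actual ElapsedTimer may differ.
    """

    def __init__(self) -> None:
        self.start = None
        self.elapsed = None

    def __enter__(self) -> "ElapsedTimer":
        # perf_counter() is monotonic and high resolution.
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.elapsed = time.perf_counter() - self.start
        # Returning False lets any exception propagate.
        return False


if __name__ == "__main__":
    timer = ElapsedTimer()
    with timer:
        time.sleep(0.05)
    print(f"Slept for {timer.elapsed:.3f} seconds.")
```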
The profiling data yielded some interesting insight:
```
> print(model.torch_profile_init_transformer.key_averages().table(sort_by='cpu_time_total'))
---------------------  ----------  ----------  -----------  ----------  ------------  ----------
                 Name  Self CPU %    Self CPU  CPU total %   CPU total  CPU time avg  # of Calls
---------------------  ----------  ----------  -----------  ----------  ------------  ----------
       aten::uniform_      74.31%     10.861s       74.31%     10.861s     111.966ms          97
        aten::normal_      25.64%      3.747s       25.64%      3.747s        1.873s           2
         aten::detach       0.01%     1.408ms        0.03%     3.712ms      24.914us         149
          aten::empty       0.02%     2.570ms        0.02%     2.570ms      17.247us         149
               detach       0.02%     2.304ms        0.02%     2.304ms      15.464us         149
          aten::fill_       0.01%     1.013ms        0.01%     1.013ms      40.500us          25
          aten::zero_       0.00%   305.888us        0.00%   305.888us      12.236us          25
cudaDeviceSynchronize       0.00%    16.073us        0.00%    16.073us      16.073us           1
---------------------  ----------  ----------  -----------  ----------  ------------  ----------
Self CPU time total: 14.615s
```
So aten::uniform_ and aten::normal_ were taking up, literally, 99.95% of the time during the transformer initialization. That also seemed bananas for the exact same reason calling _init_weights() seemed bananas: why was so much time and effort being spent initializing distributions when we were going to immediately overwrite all the weights when we load the model from the checkpoint?
Now, granted, had I not been using a debug build of PyTorch, those roughly 14.6 seconds spent on weight initialization would be more like unnoticeable milliseconds at best. But I didn’t know that at the time, so after a bit of digging, I found out I could subclass my Embedding and Linear layers (which were the ones contributing to the uniform and normal distribution setup time overhead) in such a way that weight initialization would be skipped entirely.
So, if you look at the gpt2.py implementation on Github, you’ll see that I’m using a bunch of NoInit classes, as follows:
```python
# ==============================================================
# Classes
# ==============================================================

# N.B. We have simple "no-init" overrides for nn.Embedding and
#      nn.Linear which skip the default initialization routines,
#      significantly reducing the time to load the model by
#      avoiding uniform and normal distribution initialization.
#      As we immediately load all the weights from the model
#      checkpoint straight after creating the model, we don't
#      need the default initialization routines.

class NoInitEmbedding(nn.Embedding):
    def reset_parameters(self):
        # Skip default normal initialization.
        pass


class NoInitLinear(nn.Linear):
    def reset_parameters(self):
        # Skip default Kaiming-uniform initialization.
        pass


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Key, query, value projections for all heads, but in a batch.
        self.c_attn = NoInitLinear(config.n_embd, 3 * config.n_embd)
        # Output projection.
        self.c_proj = NoInitLinear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # Regularization.
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    ...


class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = NoInitLinear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = NoInitLinear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    ...


class GPT:

    ...

    def _init_transformer(self):
        """
        Initialize the transformer.
        """
        config = self.config

        self.transformer = nn.ModuleDict(
            dict(
                wte=NoInitEmbedding(config.vocab_size, config.n_embd),
                wpe=NoInitEmbedding(config.block_size, config.n_embd),
                h=nn.ModuleList(
                    [Block(config) for _ in range(config.n_layer)]
                ),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = NoInitLinear(
            config.n_embd,
            config.vocab_size,
            bias=False,
        )
        self.transformer.wte.weight = self.lm_head.weight
```
A few weeks after I’d put this code in, I came across this HuggingFace article on Big Model Inference, which discusses the very problem I was hitting (a problem for much larger models even when you’re using a release build of PyTorch, not just pip-squeak GPT2 models with a debug build), and offers a context-manager-based solution:
```python
from accelerate import init_empty_weights

with init_empty_weights():
    my_model = ModelClass(...)
```
Granted, I couldn’t use accelerate because it depends on packages that are not available for free-threaded builds yet, but I wanted to include the information here for future reference.
Can We Load Other Checkpoints?
GPT2 was the last model where OpenAI made the weights publicly available. So, in theory, I should be able to download their largest GPT2 model—GPT2 XL—figure out how to extract the weights, and then load them into our janky little GPT class instead of the ones from the locally-trained checkpoint.
Turns out downloading it is easy via huggingface-cli (again, not something that works with free-threading, so you’ll need to activate the py313 environment and pip install -U "huggingface_hub[cli]" per these instructions):
The pytorch_model.bin one is a pickled dict of tensors obtained from a torch.save() call—i.e. basically equivalent to the checkpoint we’d been using for the locally trained GPT2 model we used. For whatever reason, that one didn’t interest me much, but model.safetensors did.
Thankfully, I could pip install safetensors in the free-threaded py313t environment, so I wrote some helper glue to give me back a HuggingFaceModel with the tensors accessible via a safetensors attribute if I passed it the appropriate model name, e.g.:
```python
import json
import logging
import os
from dataclasses import dataclass


@dataclass
class HuggingFaceModel:
    name: str
    config: dict
    safetensors: "safetensors.safe_open"
    tokenizer: dict
    tokenizer_config: dict
    vocab: dict


def get_huggingface_model(model_name: str) -> HuggingFaceModel:
    """
    Returns a Hugging Face model object for the given model name.

    Args:
        model_name (str): Supplies the name of the Hugging Face model.  This
            should be in the format of `namespace/model`, e.g. for GPT2 XL:
            `openai-community/gpt2-xl`.  This will be expanded out to the
            following directory:
                `~/.cache/huggingface/hub/models--openai-community--gpt2-xl`

    Returns:
        HuggingFaceModel: Returns a HuggingFaceModel object containing the
            model name, configuration, and SafeTensors object.
    """
    base = os.path.expanduser('~/.cache/huggingface/hub/models--')
    (namespace, model) = model_name.split('/')
    base_path = f'{base}{namespace}--{model}'
    ref_path = f'{base_path}/refs/main'
    with open(ref_path, 'r') as f:
        ref = f.read().strip()
    snapshots_dir = f'{base_path}/snapshots/{ref}'
    safetensors_path = f'{snapshots_dir}/model.safetensors'

    import safetensors
    timer = ElapsedTimer()
    logging.debug(f'About to load safetensors from {safetensors_path}...')
    with timer:
        st = safetensors.safe_open(
            safetensors_path,
            framework="pt",
            device="cpu",
        )
    msg = (
        f'Loaded safetensors from {safetensors_path} '
        f'in {timer.elapsed:.4f} seconds.'
    )
    logging.info(msg)

    config_path = f'{snapshots_dir}/config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)

    tokenizer_path = f'{snapshots_dir}/tokenizer.json'
    with open(tokenizer_path, 'r') as f:
        tokenizer = json.load(f)

    tokenizer_config_path = f'{snapshots_dir}/tokenizer_config.json'
    with open(tokenizer_config_path, 'r') as f:
        tokenizer_config = json.load(f)

    vocab_path = f'{snapshots_dir}/vocab.json'
    with open(vocab_path, 'r') as f:
        vocab = json.load(f)

    return HuggingFaceModel(
        model_name,
        config,
        st,
        tokenizer,
        tokenizer_config,
        vocab,
    )
```
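The directory resolution at the top of `get_huggingface_model()` can be isolated into a small, testable function. The sketch below, with the hypothetical name `resolve_snapshot_dir()`, follows the same layout the code above relies on: `models--{namespace}--{model}/refs/main` holds a commit hash, and the downloaded files live under `snapshots/{hash}/`.

```python
from pathlib import Path


def resolve_snapshot_dir(model_name: str, hub_dir: str,
                         revision: str = "main") -> Path:
    """
    Map `namespace/model` to its snapshot directory within a Hugging
    Face hub cache directory (hypothetical helper).
    """
    namespace, model = model_name.split("/")
    repo_dir = Path(hub_dir) / f"models--{namespace}--{model}"
    # refs/<revision> contains the commit hash of the cached snapshot.
    commit = (repo_dir / "refs" / revision).read_text().strip()
    return repo_dir / "snapshots" / commit
```

Passing `os.path.expanduser('~/.cache/huggingface/hub')` as `hub_dir` matches the default location hard-coded above; if the cache has been relocated (e.g. via the `HF_HOME` environment variable), pass the corresponding hub directory instead.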
Now, I remember Andrej had a GPT.from_pretrained() routine that was geared toward loading the GPT2 models via the transformers Python package along the following lines:
```python
from transformers import GPT2LMHeadModel

hf_model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
```
The technique he used to prime his local GPT class from the larger model loaded from HuggingFace piqued my interest. I’ve reproduced the applicable code below, with some formatting tweaks only.
```python
# create a from-scratch initialized minGPT model
config = GPTConfig(**config_args)
model = GPT(config)
sd = model.state_dict()
sd_keys = sd.keys()
# discard this mask / buffer, not a param
sd_keys = [
    key for key in sd_keys if not key.endswith('.attn.bias')
]

# init a huggingface/transformers model
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
sd_hf = model_hf.state_dict()

# copy while ensuring all of the parameters are aligned
# and match in names and shapes
sd_keys_hf = sd_hf.keys()
# ignore `.attn.masked_bias`; just a buffer
sd_keys_hf = [
    key for key in sd_keys_hf
    if not key.endswith('.attn.masked_bias')
]
# ditto; ignore `.attn.bias`, just the mask (buffer)
sd_keys_hf = [
    key for key in sd_keys_hf
    if not key.endswith('.attn.bias')
]
transposed = [
    'attn.c_attn.weight',
    'attn.c_proj.weight',
    'mlp.c_fc.weight',
    'mlp.c_proj.weight',
]
# basically the openai checkpoints use a "Conv1D" module,
# but we only want to use a vanilla Linear, this means
# that we have to transpose these weights when we import them
assert len(sd_keys_hf) == len(sd_keys), (
    f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
)
for k in sd_keys_hf:
    if any(k.endswith(w) for w in transposed):
        # special treatment for the Conv1D weights
        # we need to transpose
        assert sd_hf[k].shape[::-1] == sd[k].shape
        with torch.no_grad():
            sd[k].copy_(sd_hf[k].t())
    else:
        # vanilla copy over the other parameters
        assert sd_hf[k].shape == sd[k].shape
        with torch.no_grad():
            sd[k].copy_(sd_hf[k])
```
Do all of that fiddling and voilà, you’ve just primed your GPT model with the OpenAI weights!
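The shape bookkeeping at the heart of that loop can be exercised without PyTorch at all. The sketch below uses plain tuples for shapes, and `plan_weight_copy()` is a hypothetical helper: it reports, per parameter, whether a straight copy or a transpose is needed, asserting shapes the same way Andrej’s code does. The transpose exists because HuggingFace’s `Conv1D` stores weights as `(in_features, out_features)`, whereas `nn.Linear` stores `(out_features, in_features)`.

```python
TRANSPOSED = (
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
)
IGNORED = (".attn.bias", ".attn.masked_bias")  # Buffers, not parameters.


def plan_weight_copy(our_shapes: dict, hf_shapes: dict) -> dict:
    """
    Map each parameter name to "copy" or "transpose", checking shapes.
    Both arguments map parameter names to shape tuples.
    """
    ours = {k: s for k, s in our_shapes.items() if not k.endswith(IGNORED)}
    theirs = {k: s for k, s in hf_shapes.items() if not k.endswith(IGNORED)}
    assert len(ours) == len(theirs), (
        f"mismatched keys: {len(theirs)} != {len(ours)}"
    )
    plan = {}
    for key, hf_shape in theirs.items():
        if key.endswith(TRANSPOSED):
            # Conv1D (in, out) must become Linear (out, in).
            assert hf_shape[::-1] == ours[key], key
            plan[key] = "transpose"
        else:
            assert hf_shape == ours[key], key
            plan[key] = "copy"
    return plan
```

For example, with `n_embd=4`, a `c_attn.weight` of shape `(4, 12)` in the HuggingFace checkpoint maps to `"transpose"` against our `(12, 4)`, while its `(12,)` bias maps to `"copy"`.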
The reason Andrej’s approach piqued my interest—in combination with poking around at a safetensors instance and realizing I could easily extract all 147-or-so tensors by name—was that it may allow us to copy the tensors from the model read from disk into our model in parallel, using multiple threads.
So I hacked together a variant of Andrej’s GPT.from_pretrained() that looked as follows:
```python
class GPT:

    ...

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        map_location: Optional[str] = None,
        manual_seed: Optional[int] = None,
        torch_profile_activities: Optional[List[type]] = None,
    ) -> "GPT":
        """
        Load a GPT model from a pretrained model.

        Arguments:

            model_name (str): Supplies the model name to use.  See the
                docstring for `.util.get_huggingface_model()` for more
                information about the format.

            map_location (str): Optionally supplies the device to map the
                loaded tensor parameters to.  If None, "cuda" will be used
                if available, otherwise "cpu".

            manual_seed (int): Optionally supplies the manual seed to use
                for the model.  If None, a random seed will be used.

            torch_profile_activities (list): Optionally supplies a list of
                torch.profiler.ProfilerActivity to profile.
        """
        if manual_seed is None:
            # Use a random seed.
            manual_seed = random.randint(0, 2**32 - 1)

        if map_location is None:
            if torch.cuda.is_available():
                map_location = "cuda"
            else:
                map_location = "cpu"

        timer = ElapsedTimer()
        with timer:
            hf_model = get_huggingface_model(model_name)
        msg = (
            f'Loaded HuggingFace model {model_name} in '
            f'{timer.elapsed:.3f} seconds.'
        )
        logging.info(msg)

        config = GPTConfig(**{
            'block_size': hf_model.config['n_ctx'],
            'vocab_size': hf_model.config['vocab_size'],
            'n_layer': hf_model.config['n_layer'],
            'n_head': hf_model.config['n_head'],
            'n_embd': hf_model.config['n_embd'],
        })
        checkpoint = GPTCheckpoint(**{
            'model': None,
            'step': 0,
            'val_loss': 0.0,
            'config': config,
        })

        with timer:
            model = cls(
                checkpoint=checkpoint,
                device=map_location,
                manual_seed=manual_seed,
                torch_profile_activities=torch_profile_activities,
            )
        logging.info(f'Created GPT model in {timer.elapsed:.3f} seconds.')

        # This logic is based heavily off build-nanogpt's `train_gpt2.py`;
        # specifically: GPT.from_pretrained().

        exclude = ('.attn.bias', '.attn.masked_bias', 'lm_head.weight')
        transpose = (
            'attn.c_attn.weight',
            'attn.c_proj.weight',
            'mlp.c_fc.weight',
            'mlp.c_proj.weight',
        )

        # Identify the HuggingFace keys we're interested in.
        st = hf_model.safetensors

        # Identify our model keys we're interested in.
        sd = model.state_dict()
        sd_keys = [k for k in sd.keys() if not k.endswith(exclude)]
        hf_keys = [k.replace('transformer.', '') for k in sd_keys]

        # Copying tensors in parallel yields decent speedups,
        # at least on my V100s which have five concurrent copy
        # engines.
        def copy_tensor(hf_key, sd_key):                            # (1)
            hf_tensor = st.get_tensor(hf_key)
            if hf_key.endswith(transpose):
                assert hf_tensor.shape[::-1] == sd[sd_key].shape
                with torch.no_grad():
                    sd[sd_key].copy_(hf_tensor.t())
            else:
                assert hf_tensor.shape == sd[sd_key].shape
                with torch.no_grad():
                    sd[sd_key].copy_(hf_tensor)

        keys = zip(hf_keys, sd_keys)
        max_workers = min(os.cpu_count(), len(sd_keys))
        with timer:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:  # (2)
                futures = {
                    executor.submit(copy_tensor, hf_key, sd_key): (hf_key, sd_key)
                    for (hf_key, sd_key) in keys
                }
                for future in as_completed(futures):
                    future.result()
        logging.info(
            f'Copied weights with {max_workers} thread(s) '
            f'in {timer.elapsed:.3f} seconds.'
        )

        device = map_location
        with timer:
            model.to(device)
        msg = f'Moved model to {device} in {timer.elapsed:.3f} seconds.'
        logging.info(msg)

        return model
```
1
Define a copy_tensor function that receives the name of a tensor key as it appears in the HuggingFace model (hf_key) and as it appears in our model’s state_dict (sd_key). The sd state dict is accessible to all threads, and each thread writes to a different tensor, so a thread can simply copy its tensor via sd[sd_key].copy_(hf_tensor) (transposing first for the weights listed in transpose) without any additional locking.
2
Use a ThreadPoolExecutor() to dispatch the copy_tensor operations in parallel. The number of workers is the smaller of the CPU core count and the number of tensors to copy.
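The copy-in-parallel pattern above doesn’t depend on PyTorch itself. Here’s a minimal sketch of it with tensors swapped for plain 2-D Python lists, a stand-in assumption so it runs anywhere; parallel_copy(), copy_one(), and the toy key names are hypothetical and not part of parallelopedia.

```python
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

# Suffixes whose weights are stored transposed (mirrors the `transpose`
# tuple in from_pretrained() above).
TRANSPOSE = (
    'attn.c_attn.weight',
    'attn.c_proj.weight',
    'mlp.c_fc.weight',
    'mlp.c_proj.weight',
)


def shape(matrix):
    """Return (rows, cols) for a 2-D list of lists."""
    return (len(matrix), len(matrix[0]))


def transpose(matrix):
    """List-of-lists stand-in for tensor.t()."""
    return [list(row) for row in zip(*matrix)]


def parallel_copy(source, dest, key_pairs):
    """Copy source[src_key] into dest[dst_key] in place, one task per pair."""

    def copy_one(src_key, dst_key):
        value = source[src_key]
        if src_key.endswith(TRANSPOSE):
            value = transpose(value)
        assert shape(value) == shape(dest[dst_key])
        # Slice assignment mutates the existing destination object, much
        # like tensor.copy_(); each task owns a distinct key, so no lock.
        dest[dst_key][:] = [list(row) for row in value]

    # Never fewer than one worker, never more than cores or tasks.
    max_workers = max(1, min(os.cpu_count() or 1, len(key_pairs)))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(copy_one, src, dst): (src, dst)
            for (src, dst) in key_pairs
        }
        for future in as_completed(futures):
            # result() re-raises any exception raised inside the worker.
            future.result()
    return dest
```

Because the list-based copies are pure Python, only a free-threaded interpreter can actually run them in parallel; on a GIL build the same code still works, but the threads take turns.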
That all worked pretty well: free-threading can absolutely be used to speed up tasks like loading tensors.
Additionally, if you play around with this locally, the gpt2-xl model can be selected from the drop-down in the UI, and can also be used via the command line (--model gpt2-xl) or in the REST /generate endpoint as a query parameter (/generate/foo...?model=gpt2-xl).
Multi-GPU Support
The final big change I introduced was multi-GPU support, plus some very rudimentary round-robin behavior that the generate() routines can use to obtain a reference to a model based on the incoming user’s request.
The generate() routine now obtains a model by the following:
model = get_next_model()
The implementation is here, reproduced in part below:
NUM_GPUS = torch.cuda.device_count()

# Add a CPU version at the end.
TOTAL_MODELS = NUM_GPUS + 1
MODELS = [None] * TOTAL_MODELS
MODELS_ROUND_ROBIN = itertools.cycle(range(TOTAL_MODELS))

def get_next_model_random():
    # Randomly select a GPU to use.
    return MODELS[random.randint(0, TOTAL_MODELS - 1)]

def get_next_model_round_robin():
    with MODELS_LOCK:
        index = next(MODELS_ROUND_ROBIN)
    return MODELS[index]

get_next_model = get_next_model_round_robin

def load_models():
    global MODELS
    max_workers = min(TOTAL_MODELS, os.cpu_count())
    timer = ElapsedTimer()
    with timer:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(
                    GPT.from_local_pretrained,
                    model_path=MODEL_CHECKPOINT,
                    map_location=f'cuda:{i}',
                    torch_profile_activities=TORCH_PROFILE_ACTIVITIES,
                ): i for i in range(NUM_GPUS)
            }
            # Add the CPU model.
            futures[executor.submit(
                GPT.from_local_pretrained,
                model_path=MODEL_CHECKPOINT,
                map_location='cpu',
                torch_profile_activities=TORCH_PROFILE_ACTIVITIES,
            )] = NUM_GPUS
            for future in as_completed(futures):
                i = futures[future]
                model = future.result()
                MODELS[i] = model
    msg = (
        f'Loaded model on {NUM_GPUS} GPU(s) and 1 CPU in '
        f'{timer.elapsed:.3f} seconds.'
    )
    logging.info(msg)

PRETRAINED_MODELS = [None] * TOTAL_MODELS
PRETRAINED_MODELS_ROUND_ROBIN = itertools.cycle(range(TOTAL_MODELS))

def get_next_pretrained_model_random():
    # Randomly select a GPU to use.
    return PRETRAINED_MODELS[random.randint(0, TOTAL_MODELS - 1)]

def get_next_pretrained_model_round_robin():
    with PRETRAINED_MODELS_LOCK:
        index = next(PRETRAINED_MODELS_ROUND_ROBIN)
    return PRETRAINED_MODELS[index]

get_next_pretrained_model = get_next_pretrained_model_round_robin

def load_pretrained_models():
    global PRETRAINED_MODELS
    max_workers = min(TOTAL_MODELS, os.cpu_count())
    timer = ElapsedTimer()
    with timer:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(
                    GPT.from_pretrained,
                    model_name='openai-community/gpt2-xl',
                    map_location=f'cuda:{i}',
                    torch_profile_activities=TORCH_PROFILE_ACTIVITIES,
                ): i for i in range(NUM_GPUS)
            }
            # Add the CPU model.
            futures[executor.submit(
                GPT.from_pretrained,
                model_name='openai-community/gpt2-xl',
                map_location='cpu',
                torch_profile_activities=TORCH_PROFILE_ACTIVITIES,
            )] = NUM_GPUS
            for future in as_completed(futures):
                i = futures[future]
                model = future.result()
                PRETRAINED_MODELS[i] = model
    msg = (
        f'Loaded gpt2-xl model on {NUM_GPUS} GPU(s) and 1 CPU in '
        f'{timer.elapsed:.3f} seconds.'
    )
    logging.info(msg)
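The round-robin selection in get_next_model_round_robin() can be pulled out into a small reusable class. The following is a minimal sketch; RoundRobin is a hypothetical name, and its lock plays the role of MODELS_LOCK, whose definition isn’t shown in the excerpt above. Concurrent next() calls on a shared itertools iterator aren’t documented as thread-safe, so the lock makes each advance of the cycle atomic.

```python
import itertools
import threading


class RoundRobin:
    """Hand out items in round-robin order; safe to call from many threads."""

    def __init__(self, items):
        self._items = list(items)
        self._indices = itertools.cycle(range(len(self._items)))
        # The cycle iterator is shared mutable state; serialize access to it.
        self._lock = threading.Lock()

    def get(self):
        with self._lock:
            index = next(self._indices)
        # Reading from the list doesn't need the lock: it is never mutated.
        return self._items[index]
```

Usage would look like `models = RoundRobin(loaded_models)` followed by `model = models.get()` inside each request handler, where loaded_models is whatever list load_models() populated.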
Model Optimization
PyTorch offers two built-in facilities for speeding up a model: TorchScript (torch.jit) and Torch Dynamo (torch.compile()), the latter being the newer approach.
In our parallelopedia.gpt2 module, our Gpt2App class has a classmethod named init_once(), which, as you can probably guess, gets called once by the HTTP server when starting up. This is where you stash expensive setup code like loading and compiling models.
The code looks similar to the following. We have two globals, TRY_JIT_COMPILE and TRY_TORCH_COMPILE, that, when set, will attempt to optimize the models using the selected method.
class Gpt2App(HttpApp):
    ...

    @classmethod
    def init_once(cls):
        load_models()
        load_pretrained_models()

        global MODELS, PRETRAINED_MODELS

        # This doesn't work because torch.jit doesn't handle our
        # async generator.
        global TRY_JIT_COMPILE
        if TRY_JIT_COMPILE:
            for (i, model) in enumerate(MODELS):
                model.config = dataclasses.asdict(model.config)
                timer = ElapsedTimer()
                with timer:
                    model = torch.jit.script(model)
                MODELS[i] = model
                logging.info(
                    f'JIT compiled model {i} in {timer.elapsed:.3f} seconds.'
                )

        global TRY_TORCH_COMPILE
        if TRY_TORCH_COMPILE:
            for (i, model) in enumerate(MODELS):
                model.config = dataclasses.asdict(model.config)
                timer = ElapsedTimer()
                with timer:
                    model = torch.compile(model)
                MODELS[i] = model
                logging.info(
                    f'torch.compiled model {i} in '
                    f'{timer.elapsed:.3f} seconds.'
                )

            for (i, model) in enumerate(PRETRAINED_MODELS):
                model.config = dataclasses.asdict(model.config)
                timer = ElapsedTimer()
                with timer:
                    model = torch.compile(model)
                PRETRAINED_MODELS[i] = model
                logging.info(
                    f'torch.compiled pretrained model {i} in '
                    f'{timer.elapsed:.3f} seconds.'
                )
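ElapsedTimer is used throughout these snippets, but its definition isn’t shown here. Assuming it’s a reusable context manager that records wall-clock time in an elapsed attribute, a minimal stand-in might look like the following (a hypothetical re-creation, not parallelopedia’s actual class):

```python
import time


class ElapsedTimer:
    """Context manager that stores the duration of its block in `.elapsed`."""

    def __init__(self):
        self.elapsed = 0.0
        self._start = None

    def __enter__(self):
        # perf_counter() is monotonic and high resolution, which suits timing.
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = time.perf_counter() - self._start
        # Returning False lets any exception from the block propagate.
        return False
```

Because each `with` resets the start time, one instance can time several consecutive blocks, as from_pretrained() does. The instance holds mutable state, though, so under free-threading each thread should use its own.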
TorchScript
TorchScript doesn’t work at all for our model: it balks on the async def generate_async_for() routine that is the workhorse of our asynchronous token-by-token generation.
And that ended my TorchScript experiment :-)
Torch Dynamo (torch.compile)
Torch Dynamo, introduced in PyTorch 2.0, hooks into the Python interpreter, traces model execution, and then builds optimized kernels based on the information observed during that runtime tracing.
When it works, it works really well, and you can get significant speedups both in training and inference with literally a single line:
model = torch.compile(model)
The first problem we hit with Dynamo is that it’s explicitly not supported by PyTorch on free-threaded builds:
>>> model = torch.compile(model)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 1
----> 1 model = torch.compile(model)

File ~/mambaforge/envs/py313t/lib/python3.13t/site-packages/torch/__init__.py:2526, in compile(model, fullgraph, dynamic, backend, mode, options, disable)
   2524     raise RuntimeError("torch.compile is not supported on Python 3.14+")
   2525 elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1:
-> 2526     raise RuntimeError(
   2527         "torch.compile is not supported on Python built with GIL disabled"
   2528     )
   2530 # Decorator mode
   2531 if model is None:

RuntimeError: torch.compile is not supported on Python built with GIL disabled
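The check torch trips over is sysconfig.get_config_var("Py_GIL_DISABLED"), which reports how the interpreter was built. Whether the GIL is actually running is a separate question: importing an extension module that hasn’t declared free-threading support can re-enable it at runtime (note the triton RuntimeWarning in the test log further down). A small sketch that reports both; gil_status() is a hypothetical helper, and because sys._is_gil_enabled() is a private function that only exists on 3.13+, it falls back to True on older interpreters, which always have a GIL.

```python
import sys
import sysconfig


def gil_status():
    """Report the build-time and runtime GIL state of this interpreter."""
    # 1 on free-threaded builds; 0 or None everywhere else.
    free_threaded_build = sysconfig.get_config_var("Py_GIL_DISABLED") == 1
    # Private 3.13+ API; older interpreters always run with the GIL.
    is_enabled = getattr(sys, "_is_gil_enabled", None)
    gil_enabled_at_runtime = is_enabled() if is_enabled is not None else True
    return {
        "free_threaded_build": free_threaded_build,
        "gil_enabled_at_runtime": gil_enabled_at_runtime,
    }


if __name__ == "__main__":
    print(gil_status())
```

On a py313t build run with -Xgil=0 you’d expect both a free-threaded build and a disabled GIL, unless some imported module has since switched it back on.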
Let’s hack that torch/__init__.py file as follows and try again.
--- __init__.py.orig    2025-02-09 13:28:27.892979258 -0800
+++ __init__.py 2025-02-09 13:30:13.879909529 -0800
@@ -2523,9 +2523,7 @@
     if sys.version_info >= (3, 14):
         raise RuntimeError("torch.compile is not supported on Python 3.14+")
     elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1:
-        raise RuntimeError(
-            "torch.compile is not supported on Python built with GIL disabled"
-        )
+        print("torch.__init__: Ignoring unsupported torch.compile() with no GIL unsupported")
 
     # Decorator mode
     if model is None:
>>> model = torch.compile(model)
torch.__init__: Ignoring unsupported torch.compile() with no GIL unsupported
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 1
----> 1 model = torch.compile(model)

File ~/mambaforge/envs/py313t/lib/python3.13t/site-packages/torch/__init__.py:2563, in compile(model, fullgraph, dynamic, backend, mode, options, disable)
   2560 else:
   2561     backend = _TorchCompileWrapper(backend, mode, options, dynamic)
-> 2563 return torch._dynamo.optimize(
   2564     backend=backend,
   2565     nopython=fullgraph,
   2566     dynamic=dynamic,
   2567     disable=disable,
   2568 )(model)

File ~/mambaforge/envs/py313t/lib/python3.13t/site-packages/torch/_dynamo/eval_frame.py:842, in optimize(*args, **kwargs)
    839     kwargs["nopython"] = ca_kwargs_override["fullgraph"]
    840     return optimize(*args, **kwargs)
--> 842 return _optimize(rebuild_ctx, *args, **kwargs)

File ~/mambaforge/envs/py313t/lib/python3.13t/site-packages/torch/_dynamo/eval_frame.py:881, in _optimize(rebuild_ctx, backend, nopython, guard_export_fn, guard_fail_fn, disable, dynamic)
    845 def _optimize(
    846     rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
    847     backend="inductor",
   (...)
    853     dynamic=None,
    854 ) -> Union[OptimizeContext, _NullDecorator]:
    855     """
    856     The main entrypoint of TorchDynamo.  Do graph capture and call
    857     backend() to optimize extracted graphs.
   (...)
    879         ...
    880     """
--> 881     check_if_dynamo_supported()
    882     # Note: The hooks object could be global instead of passed around, *however* that would make
    883     # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
    884     # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
    885     # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an
    886     # easier to understand UX at the cost of a little more plumbing on our end.
    887     hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)

File ~/mambaforge/envs/py313t/lib/python3.13t/site-packages/torch/_dynamo/eval_frame.py:805, in check_if_dynamo_supported()
    803     raise RuntimeError("Python 3.14+ not yet supported for torch.compile")
    804 elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1:
--> 805     raise RuntimeError(
    806         "torch.compile is not supported on Python built with GIL disabled"
    807     )

RuntimeError: torch.compile is not supported on Python built with GIL disabled
Let’s hack torch/_dynamo/eval_frame.py too:
--- eval_frame.py.orig  2025-02-09 13:32:18.266470283 -0800
+++ eval_frame.py       2025-02-09 13:32:32.746291774 -0800
@@ -802,9 +802,7 @@
     if sys.version_info >= (3, 14):
         raise RuntimeError("Python 3.14+ not yet supported for torch.compile")
     elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1:
-        raise RuntimeError(
-            "torch.compile is not supported on Python built with GIL disabled"
-        )
+        print("torch._dynamo.eval_frame: Ignoring unsupported torch.compile() with no GIL unsupported")
 
 
 def is_dynamo_supported():
Now let’s try again:
>>> model = torch.compile(model)
torch.__init__: Ignoring unsupported torch.compile() with no GIL unsupported
torch._dynamo.eval_frame: Ignoring unsupported torch.compile() with no GIL unsupported
>>>
So, we can forcibly coerce PyTorch Dynamo to compile our model even if we’re in free-threaded mode.
But does it work? And is it faster? Let’s investigate.
We can invoke our parallelopedia.gpt2 module directly with various command line arguments to test out generation performance. The accompanying --help is furnished below for reference:
% python -m parallelopedia.gpt2 --help
torch.__init__: Ignoring unsupported torch.compile() with no GIL unsupported
torch._dynamo.eval_frame: Ignoring unsupported torch.compile() with no GIL unsupported
usage: gpt2.py [-h] [--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}]
               [--model {gpt2,gpt2-xl}] [--device DEVICE]
               [--max-length MAX_LENGTH] [--top-k TOP_K] [--seed SEED]
               [--prompt PROMPT] [--torch-compile] [--torch-jit]
               [--torch-compile-fullgraph]
               [--torch-compile-reduce-overhead]
               [--torch-compile-max-autotune] [--generate-slim]
               [--rounds ROUNDS] [--wrap WRAP] [--note NOTE]

Run the GPT2 module.

options:
  -h, --help            show this help message and exit
  --log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
                        Set the logging level.
  --model {gpt2,gpt2-xl}
                        Select the model to use.
  --device DEVICE       Select the device to use.
  --max-length MAX_LENGTH
                        Set the maximum length of the generated text.
  --top-k TOP_K         Set the top-k value for sampling.
  --seed SEED           Set the random seed for generation.
  --prompt PROMPT       Set the prompt for generation.
  --torch-compile       Compile the models using torch.compile().
  --torch-jit           Compile the models using torch.jit.script().
  --torch-compile-fullgraph
                        Compile the models using torch.compile() with
                        fullgraph=True.
  --torch-compile-reduce-overhead
                        Compile the models using torch.compile() with
                        mode="reduce-overhead"
  --torch-compile-max-autotune
                        Compile the models using torch.compile() with
                        mode="max_autotune".
  --generate-slim       Use the generate_slim() method for generation.
  --rounds ROUNDS       Set the number of rounds for generation.
  --wrap WRAP           Set the wrap width for text output.
  --note NOTE           Set a note to include in the JSON output.
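Some of these flags can’t be combined: main() rejects --torch-compile together with --torch-jit, and --torch-compile-reduce-overhead together with --torch-compile-max-autotune, by checking by hand. argparse can express the same constraints declaratively via add_mutually_exclusive_group(). The sketch below is an alternative design, not the module’s actual parse_arguments(); compile_kwargs() is a hypothetical helper that reproduces main()’s flag-to-torch.compile() mapping without importing torch. Note that main() passes mode='max-autotune' (hyphenated), the spelling torch.compile() accepts, even though the help text says "max_autotune".

```python
import argparse


def build_parser():
    """Parser with only the compile-related flags (hypothetical subset)."""
    parser = argparse.ArgumentParser(prog="gpt2-compile-sketch")
    # At most one compilation backend.
    backend = parser.add_mutually_exclusive_group()
    backend.add_argument("--torch-compile", action="store_true")
    backend.add_argument("--torch-jit", action="store_true")
    parser.add_argument("--torch-compile-fullgraph", action="store_true")
    # At most one torch.compile() mode.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--torch-compile-reduce-overhead", action="store_true")
    mode.add_argument("--torch-compile-max-autotune", action="store_true")
    return parser


def compile_kwargs(args):
    """Translate parsed flags into keyword arguments for torch.compile()."""
    kwds = {}
    if args.torch_compile_fullgraph:
        kwds["fullgraph"] = True
    if args.torch_compile_reduce_overhead:
        kwds["mode"] = "reduce-overhead"
    elif args.torch_compile_max_autotune:
        kwds["mode"] = "max-autotune"
    return kwds
```

With the groups in place, argparse prints a "not allowed with argument" error and exits with status 2, so main() no longer needs its own ValueError checks.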
Depending on the command-line parameters supplied, we can test various generation variants: no torch.compile() at all, torch.compile() with default options, and more advanced permutations such as fullgraph=True combined with mode="reduce-overhead" or mode="max-autotune". The full source code for the module’s main() function follows:
def main():
    """
    Main entry point for the parallelopedia.gpt2 module.
    """
    args = parse_arguments()

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    start_time = time.time()
    start_timestamp = datetime.datetime.now().isoformat()

    timer = ElapsedTimer()
    with timer:
        if args.model == 'gpt2-xl':
            model = GPT.from_pretrained(
                model_name='openai-community/gpt2-xl',
                map_location=args.device,
            )
        else:
            model = GPT.from_local_pretrained(
                model_path=MODEL_CHECKPOINT,
                map_location=args.device,
                manual_seed=args.seed,
            )
    logging.info(
        f'Loaded {args.model} on {args.device} '
        f'in {timer.elapsed:.3f} seconds.'
    )

    if args.torch_compile:
        if args.torch_jit:
            msg = 'Cannot specify both --torch-compile and --torch-jit.'
            raise ValueError(msg)
        model.config = dataclasses.asdict(model.config)
        kwds = {}
        if args.torch_compile_fullgraph:
            kwds['fullgraph'] = True
        if args.torch_compile_reduce_overhead:
            if args.torch_compile_max_autotune:
                msg = (
                    'Cannot specify both --torch-compile-reduce-overhead '
                    'and --torch-compile-max-autotune.'
                )
                raise ValueError(msg)
            kwds['mode'] = 'reduce-overhead'
        elif args.torch_compile_max_autotune:
            kwds['mode'] = 'max-autotune'
        with timer:
            model = torch.compile(model, **kwds)
        msg = f'torch.compiled model in {timer.elapsed:.3f} seconds.'
        logging.info(msg)
    elif args.torch_jit:
        model.config = dataclasses.asdict(model.config)
        with timer:
            model = torch.jit.script(model)
        msg = f'JIT compiled model in {timer.elapsed:.3f} seconds.'
        logging.info(msg)

    seed = args.seed
    if seed is None or seed == '':
        seed = random.randint(0, 2**32 - 1)

    if args.generate_slim:
        text_tokens = model.enc.encode(args.prompt)
        prompt_token_length = len(text_tokens)

    rates = []

    for i in range(args.rounds):
        logging.info(f'Round {i + 1} of {args.rounds}.')
        if args.generate_slim:
            with timer:
                x = model.generate_slim(
                    text_tokens,
                    max_length=args.max_length,
                    top_k=args.top_k,
                    seed=seed,
                )
            elapsed = timer.elapsed
            count = x.size(1) - prompt_token_length
            tokens_per_sec = count / elapsed
            rates.append(tokens_per_sec)
            logging.info(
                f'Generated {count} tokens in {elapsed:.2f} seconds '
                f'({tokens_per_sec:.2f} tokens/sec)'
            )
            output = model.enc.decode(x[0].tolist())
        else:
            save_rate = lambda x: rates.append(x)
            output = model.generate(
                args.prompt,
                max_length=args.max_length,
                top_k=args.top_k,
                seed=seed,
                save_rate=save_rate,
            )

        if args.wrap:
            output = textwrap.fill(output, width=args.wrap)
        logging.info(f'Output:\n{output}')

    # The filename is of the form:
    #   `gpt2-rates-<yyyy-mm-dd-hh-ss-mm.sss>-[optional].json`
    now = datetime.datetime.now()
    timestamp = now.strftime('%Y-%m-%d-%H-%M-%S-%f')
    filename = f"gpt2-rates-{timestamp}"
    if args.torch_compile:
        filename += '-torch-compile'
        if args.torch_compile_reduce_overhead:
            filename += '-reduce-overhead'
        elif args.torch_compile_max_autotune:
            filename += '-max-autotune'
        if args.torch_compile_fullgraph:
            filename += '-fullgraph'
    if args.generate_slim:
        filename += '-generate-slim'

    conda_env_name = os.environ.get('CONDA_DEFAULT_ENV', 'Unknown')
    filename += f'-{conda_env_name}'
    filename += '.json'

    if not isinstance(model.config, dict):
        model_config = dataclasses.asdict(model.config)
    else:
        model_config = model.config

    end_time = time.time()
    end_timestamp = datetime.datetime.now().isoformat()

    if args.device.startswith('cuda'):
        ix = args.device.find(':')
        if ix == -1:
            device_index = 0
        else:
            device_index = int(args.device[ix+1:])
        device_name = torch.cuda.get_device_name(device_index)
    else:
        device_name = 'CPU'

    try:
        is_gil_enabled = sys._is_gil_enabled()
    except AttributeError:
        # sys._is_gil_enabled() only exists on Python 3.13+; older
        # interpreters always run with the GIL enabled.
        is_gil_enabled = True

    # Prepare a dictionary with the details to save.
    run_details = {
        "rates": rates,
        "model_config": model_config,
        "args": vars(args),
        "start_timestamp": start_timestamp,
        "end_timestamp": end_timestamp,
        "elapsed": f'{end_time - start_time:.3f}',
        "device_name": device_name,
        "conda_env_name": conda_env_name,
        "is_gil_enabled": is_gil_enabled,
        "note": args.note,
    }

    # Write the JSON file.
    with open(filename, "w") as json_file:
        json.dump(run_details, json_file, indent=4)

    logging.info(f"Run details saved to {filename}.")
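Each run leaves behind a gpt2-rates-*.json file containing the rates list and the run metadata that main() writes. A small sketch for summarizing one of those files follows; summarize_rates() is hypothetical, and dropping the first round as warm-up is an assumption, motivated by the noticeably slower first round visible in the test logs.

```python
import json
import statistics
from pathlib import Path


def summarize_rates(path, warmup=1):
    """Summarize tokens/sec from a gpt2-rates-*.json file written by main()."""
    details = json.loads(Path(path).read_text())
    # Skip the first `warmup` rounds, which include one-time setup costs.
    rates = details["rates"][warmup:]
    return {
        "rounds": len(rates),
        "mean": statistics.mean(rates),
        "median": statistics.median(rates),
        # stdev() needs at least two samples.
        "stdev": statistics.stdev(rates) if len(rates) > 1 else 0.0,
        "is_gil_enabled": details.get("is_gil_enabled"),
        "device_name": details.get("device_name"),
    }
```

Comparing the median across the files from different torch.compile() permutations is less sensitive to an occasional slow round than comparing means.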
First, let’s take a look at whether torch.compile() improves performance in our py313t free-threaded environment.
Performance Comparison
I wrote a bash script run-py313t-gpt2-compile-tests.sh, reproduced below, that ran various permutations of generation with different torch.compile() options.
#!/bin/bash

# Ensure our environment name is `py313t`.
if [ "$CONDA_DEFAULT_ENV" != "py313t" ]; then
    echo "Error: Conda environment is not 'py313t'."
    exit 1
fi

# Ensure PARALLELOPEDIA_ROOT is set.
if [ -z "$PARALLELOPEDIA_ROOT" ]; then
    echo "Error: PARALLELOPEDIA_ROOT is not set."
    exit 1
fi

SEED=42
DEVICE="cuda:3"
ROUNDS=20

OPTIONS=(
    "--torch-compile"
    "--torch-compile --torch-compile-fullgraph"
    "--torch-compile --torch-compile-reduce-overhead"
    "--torch-compile --torch-compile-reduce-overhead --torch-compile-fullgraph"
    "--torch-compile --torch-compile-reduce-overhead --torch-compile-fullgraph"
    "--torch-compile --torch-compile-max-autotune"
    "--torch-compile --torch-compile-max-autotune --torch-compile-fullgraph"
)

echo "GPT.generate() variants"
time python -Xgil=0 -m parallelopedia.gpt2 \
    --seed=$SEED \
    --rounds=$ROUNDS \
    --device=$DEVICE

for opt in "${OPTIONS[@]}"; do
    # Split opt into separate arguments.
    eval set -- $opt
    time python -Xgil=0 -m parallelopedia.gpt2 \
        --seed=$SEED \
        --rounds=$ROUNDS \
        --device=$DEVICE \
        "$@"
done
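The eval set -- $opt trick works, but the same permutation list can be built in Python with shlex.split(), which tokenizes each option string the way a POSIX shell would without evaluating it. Here is a dry-run sketch (build_commands() is hypothetical); it also skips exact duplicates, since the OPTIONS array above lists the reduce-overhead plus fullgraph combination twice.

```python
import shlex

BASE = ["python", "-Xgil=0", "-m", "parallelopedia.gpt2"]


def build_commands(seed, rounds, device, option_strings):
    """Return one argv list per run: a baseline plus each unique option set."""
    common = BASE + [f"--seed={seed}", f"--rounds={rounds}", f"--device={device}"]
    commands = [common]  # Baseline run: no torch.compile() flags.
    seen = set()
    for opt in option_strings:
        extra = shlex.split(opt)
        key = tuple(extra)
        if key in seen:
            # Identical option string already queued; skip the repeat.
            continue
        seen.add(key)
        commands.append(common + extra)
    return commands
```

Each resulting list could then be passed to subprocess.run(cmd, check=True); keeping the arguments as a list avoids any shell re-parsing of the option strings.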
The full log file for the run is captured in the callout below.
run-py313t-gpt2-compile-tests.log
GPT.generate() variants
2025-02-10 20:47:17,958 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.463 seconds.
2025-02-10 20:47:17,965 - INFO - Initialized GPT model in 0.007 seconds.
2025-02-10 20:47:18,191 - INFO - Loaded model weights in 0.226 seconds.
<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL) has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL. To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.
2025-02-10 20:47:18,558 - INFO - Created GPT model in 0.600 seconds.
2025-02-10 20:47:18,642 - INFO - Moved model to cuda:3 in 0.083 seconds.
2025-02-10 20:47:18,642 - INFO - Loaded model from step 19072, val_loss 3.0519702434539795
2025-02-10 20:47:18,642 - INFO - Loaded gpt2 on cuda:3 in 1.148 seconds.
2025-02-10 20:47:18,642 - INFO - Round 1 of 20.
2025-02-10 20:47:19,454 - INFO - Generated 91 tokens in 0.81 seconds (112.32 tokens/sec)
2025-02-10 20:47:19,454 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:19,454 - INFO - Round 2 of 20.
2025-02-10 20:47:20,091 - INFO - Generated 91 tokens in 0.64 seconds (142.89 tokens/sec)
2025-02-10 20:47:20,092 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:20,092 - INFO - Round 3 of 20.
2025-02-10 20:47:20,731 - INFO - Generated 91 tokens in 0.64 seconds (142.49 tokens/sec)
2025-02-10 20:47:20,731 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:20,731 - INFO - Round 4 of 20.
2025-02-10 20:47:21,366 - INFO - Generated 91 tokens in 0.63 seconds (143.47 tokens/sec)
2025-02-10 20:47:21,366 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:21,366 - INFO - Round 5 of 20.
2025-02-10 20:47:22,000 - INFO - Generated 91 tokens in 0.63 seconds (143.62 tokens/sec)
2025-02-10 20:47:22,000 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:22,000 - INFO - Round 6 of 20.
2025-02-10 20:47:22,634 - INFO - Generated 91 tokens in 0.63 seconds (143.62 tokens/sec)
2025-02-10 20:47:22,635 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:22,635 - INFO - Round 7 of 20.
2025-02-10 20:47:23,270 - INFO - Generated 91 tokens in 0.64 seconds (143.25 tokens/sec)
2025-02-10 20:47:23,271 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:23,271 - INFO - Round 8 of 20.
2025-02-10 20:47:23,904 - INFO - Generated 91 tokens in 0.63 seconds (143.67 tokens/sec)
2025-02-10 20:47:23,905 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:23,905 - INFO - Round 9 of 20.
2025-02-10 20:47:24,539 - INFO - Generated 91 tokens in 0.63 seconds (143.56 tokens/sec)
2025-02-10 20:47:24,539 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:24,539 - INFO - Round 10 of 20.
2025-02-10 20:47:25,174 - INFO - Generated 91 tokens in 0.63 seconds (143.56 tokens/sec)
2025-02-10 20:47:25,174 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:25,174 - INFO - Round 11 of 20.
2025-02-10 20:47:25,812 - INFO - Generated 91 tokens in 0.64 seconds (142.78 tokens/sec)
2025-02-10 20:47:25,812 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:25,812 - INFO - Round 12 of 20.
2025-02-10 20:47:26,444 - INFO - Generated 91 tokens in 0.63 seconds (143.92 tokens/sec)
2025-02-10 20:47:26,445 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:26,445 - INFO - Round 13 of 20.
2025-02-10 20:47:27,080 - INFO - Generated 91 tokens in 0.63 seconds (143.34 tokens/sec)
2025-02-10 20:47:27,080 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:27,081 - INFO - Round 14 of 20.
2025-02-10 20:47:27,716 - INFO - Generated 91 tokens in 0.63 seconds (143.34 tokens/sec)
2025-02-10 20:47:27,716 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:27,716 - INFO - Round 15 of 20.
2025-02-10 20:47:28,350 - INFO - Generated 91 tokens in 0.63 seconds (143.65 tokens/sec)
2025-02-10 20:47:28,350 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:28,350 - INFO - Round 16 of 20.
2025-02-10 20:47:28,984 - INFO - Generated 91 tokens in 0.63 seconds (143.73 tokens/sec)
2025-02-10 20:47:28,984 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:28,984 - INFO - Round 17 of 20.
2025-02-10 20:47:29,617 - INFO - Generated 91 tokens in 0.63 seconds (143.79 tokens/sec)
2025-02-10 20:47:29,618 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:29,618 - INFO - Round 18 of 20.
2025-02-10 20:47:30,260 - INFO - Generated 91 tokens in 0.64 seconds (141.71 tokens/sec)
2025-02-10 20:47:30,261 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:30,261 - INFO - Round 19 of 20.
2025-02-10 20:47:30,896 - INFO - Generated 91 tokens in 0.63 seconds (143.36 tokens/sec)
2025-02-10 20:47:30,896 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:30,896 - INFO - Round 20 of 20.
2025-02-10 20:47:31,530 - INFO - Generated 91 tokens in 0.63 seconds (143.55 tokens/sec)
2025-02-10 20:47:31,531 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:31,531 - INFO - Run details saved to gpt2-rates-2025-02-10-20-47-31-531204-py313t.json.
torch.__init__: Ignoring unsupported torch.compile() with no GIL unsupported
torch._dynamo.eval_frame: Ignoring unsupported torch.compile() with no GIL unsupported

real    0m19.736s
user    0m22.944s
sys     0m1.382s
Exception ignored in: <_io.BufferedWriter name=41>
BrokenPipeError: [Errno 32] Broken pipe
2025-02-10 20:47:37,663 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.484 seconds.
2025-02-10 20:47:37,670 - INFO - Initialized GPT model in 0.006 seconds.
2025-02-10 20:47:37,896 - INFO - Loaded model weights in 0.226 seconds.
<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL) has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL. To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.
2025-02-10 20:47:38,263 - INFO - Created GPT model in 0.600 seconds.
2025-02-10 20:47:38,347 - INFO - Moved model to cuda:3 in 0.083 seconds.
2025-02-10 20:47:38,347 - INFO - Loaded model from step 19072, val_loss 3.0519702434539795
2025-02-10 20:47:38,347 - INFO - Loaded gpt2 on cuda:3 in 1.168 seconds.
2025-02-10 20:47:38,349 - INFO - torch.compiled model in 0.002 seconds.
2025-02-10 20:47:38,349 - INFO - Round 1 of 20.
2025-02-10 20:47:39,174 - INFO - Generated 91 tokens in 0.82 seconds (110.47 tokens/sec)
2025-02-10 20:47:39,174 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:39,174 - INFO - Round 2 of 20.
2025-02-10 20:47:39,828 - INFO - Generated 91 tokens in 0.65 seconds (139.24 tokens/sec)
2025-02-10 20:47:39,829 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:39,829 - INFO - Round 3 of 20.
2025-02-10 20:47:40,483 - INFO - Generated 91 tokens in 0.65 seconds (139.04 tokens/sec)
2025-02-10 20:47:40,484 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:47:40,484 - INFO - Round 4 of 20.
2025-02-10 20:47:41,138 - INFO - Generated 91 tokens in 0.65 seconds (139.23 tokens/sec)
2025-02-10 20:47:41,138 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is2025-02-10 20:47:41,138 - INFO - Round 5 of 20.2025-02-10 20:47:41,791 - INFO - Generated 91 tokens in 0.65 seconds (139.47 tokens/sec)2025-02-10 20:47:41,791 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:41,791 - INFO - Round 6 of 20.2025-02-10 20:47:42,445 - INFO - Generated 91 tokens in 0.65 seconds (139.29 tokens/sec)2025-02-10 20:47:42,445 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:42,445 - INFO - Round 7 of 20.2025-02-10 20:47:43,099 - INFO - Generated 91 tokens in 0.65 seconds (139.09 tokens/sec)2025-02-10 20:47:43,100 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:47:43,100 - INFO - Round 8 of 20.2025-02-10 20:47:43,755 - INFO - Generated 91 tokens in 0.66 seconds (138.85 tokens/sec)2025-02-10 20:47:43,756 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:43,756 - INFO - Round 9 of 20.2025-02-10 20:47:44,409 - INFO - Generated 91 tokens in 0.65 seconds (139.29 tokens/sec)2025-02-10 20:47:44,410 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:44,410 - INFO - Round 10 of 20.2025-02-10 20:47:45,063 - INFO - Generated 91 tokens in 0.65 seconds (139.29 tokens/sec)2025-02-10 20:47:45,064 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:47:45,064 - INFO - Round 11 of 20.2025-02-10 20:47:45,718 - INFO - Generated 91 tokens in 0.65 seconds (139.11 tokens/sec)2025-02-10 20:47:45,718 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:45,719 - INFO - Round 12 of 20.2025-02-10 20:47:46,371 - INFO - Generated 91 tokens in 0.65 seconds (139.47 tokens/sec)2025-02-10 20:47:46,372 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:46,372 - INFO - Round 13 of 20.2025-02-10 20:47:47,046 - INFO - Generated 91 tokens in 0.67 seconds (135.03 tokens/sec)2025-02-10 20:47:47,046 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:47:47,046 - INFO - Round 14 of 20.2025-02-10 20:47:47,692 - INFO - Generated 91 tokens in 0.65 seconds (140.90 tokens/sec)2025-02-10 20:47:47,693 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:47,693 - INFO - Round 15 of 20.2025-02-10 20:47:48,339 - INFO - Generated 91 tokens in 0.65 seconds (140.84 tokens/sec)2025-02-10 20:47:48,339 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:48,340 - INFO - Round 16 of 20.2025-02-10 20:47:48,986 - INFO - Generated 91 tokens in 0.65 seconds (140.90 tokens/sec)2025-02-10 20:47:48,986 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:47:48,986 - INFO - Round 17 of 20.2025-02-10 20:47:49,634 - INFO - Generated 91 tokens in 0.65 seconds (140.40 tokens/sec)2025-02-10 20:47:49,635 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:49,635 - INFO - Round 18 of 20.2025-02-10 20:47:50,282 - INFO - Generated 91 tokens in 0.65 seconds (140.63 tokens/sec)2025-02-10 20:47:50,283 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:50,283 - INFO - Round 19 of 20.2025-02-10 20:47:50,929 - INFO - Generated 91 tokens in 0.65 seconds (140.73 tokens/sec)2025-02-10 20:47:50,930 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:47:50,930 - INFO - Round 20 of 20.2025-02-10 20:47:51,578 - INFO - Generated 91 tokens in 0.65 seconds (140.41 tokens/sec)2025-02-10 20:47:51,579 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:51,579 - INFO - Run details saved to gpt2-rates-2025-02-10-20-47-51-579116-torch-compile-py313t.json.torch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedtorch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedreal 0m20.073suser 0m23.400ssys 0m1.313sException ignored in: <_io.BufferedWriter name=41>BrokenPipeError: [Errno 32] Broken pipe2025-02-10 20:47:57,975 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.486 seconds.2025-02-10 20:47:57,982 - INFO - Initialized GPT model in 0.006 seconds.2025-02-10 20:47:58,209 - INFO - Loaded model weights in 0.227 seconds.<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL)has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL. 
To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.2025-02-10 20:47:58,575 - INFO - Created GPT model in 0.599 seconds.2025-02-10 20:47:58,659 - INFO - Moved model to cuda:3 in 0.084 seconds.2025-02-10 20:47:58,659 - INFO - Loaded model from step 19072, val_loss 3.05197024345397952025-02-10 20:47:58,659 - INFO - Loaded gpt2 on cuda:3 in 1.170 seconds.2025-02-10 20:47:58,660 - INFO - torch.compiled model in 0.001 seconds.2025-02-10 20:47:58,661 - INFO - Round 1 of 20.2025-02-10 20:47:59,483 - INFO - Generated 91 tokens in 0.82 seconds (110.79 tokens/sec)2025-02-10 20:47:59,483 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:47:59,483 - INFO - Round 2 of 20.2025-02-10 20:48:00,153 - INFO - Generated 91 tokens in 0.67 seconds (135.92 tokens/sec)2025-02-10 20:48:00,153 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:48:00,153 - INFO - Round 3 of 20.2025-02-10 20:48:00,796 - INFO - Generated 91 tokens in 0.64 seconds (141.66 tokens/sec)2025-02-10 20:48:00,796 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:00,796 - INFO - Round 4 of 20.2025-02-10 20:48:01,439 - INFO - Generated 91 tokens in 0.64 seconds (141.69 tokens/sec)2025-02-10 20:48:01,439 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:01,439 - INFO - Round 5 of 20.2025-02-10 20:48:02,086 - INFO - Generated 91 tokens in 0.65 seconds (140.83 tokens/sec)2025-02-10 20:48:02,086 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:48:02,086 - INFO - Round 6 of 20.2025-02-10 20:48:02,730 - INFO - Generated 91 tokens in 0.64 seconds (141.33 tokens/sec)2025-02-10 20:48:02,731 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:02,731 - INFO - Round 7 of 20.2025-02-10 20:48:03,374 - INFO - Generated 91 tokens in 0.64 seconds (141.45 tokens/sec)2025-02-10 20:48:03,375 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:03,375 - INFO - Round 8 of 20.2025-02-10 20:48:04,017 - INFO - Generated 91 tokens in 0.64 seconds (141.76 tokens/sec)2025-02-10 20:48:04,017 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:48:04,017 - INFO - Round 9 of 20.2025-02-10 20:48:04,661 - INFO - Generated 91 tokens in 0.64 seconds (141.46 tokens/sec)2025-02-10 20:48:04,661 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:04,661 - INFO - Round 10 of 20.2025-02-10 20:48:05,303 - INFO - Generated 91 tokens in 0.64 seconds (141.82 tokens/sec)2025-02-10 20:48:05,303 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:05,303 - INFO - Round 11 of 20.2025-02-10 20:48:05,945 - INFO - Generated 91 tokens in 0.64 seconds (141.91 tokens/sec)2025-02-10 20:48:05,945 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:48:05,945 - INFO - Round 12 of 20.2025-02-10 20:48:06,587 - INFO - Generated 91 tokens in 0.64 seconds (141.83 tokens/sec)2025-02-10 20:48:06,587 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:06,587 - INFO - Round 13 of 20.2025-02-10 20:48:07,229 - INFO - Generated 91 tokens in 0.64 seconds (141.93 tokens/sec)2025-02-10 20:48:07,229 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:07,229 - INFO - Round 14 of 20.2025-02-10 20:48:07,871 - INFO - Generated 91 tokens in 0.64 seconds (141.88 tokens/sec)2025-02-10 20:48:07,871 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:48:07,871 - INFO - Round 15 of 20.2025-02-10 20:48:08,513 - INFO - Generated 91 tokens in 0.64 seconds (141.74 tokens/sec)2025-02-10 20:48:08,514 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:08,514 - INFO - Round 16 of 20.2025-02-10 20:48:09,157 - INFO - Generated 91 tokens in 0.64 seconds (141.58 tokens/sec)2025-02-10 20:48:09,157 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:09,157 - INFO - Round 17 of 20.2025-02-10 20:48:09,799 - INFO - Generated 91 tokens in 0.64 seconds (141.79 tokens/sec)2025-02-10 20:48:09,800 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:48:09,800 - INFO - Round 18 of 20.2025-02-10 20:48:10,442 - INFO - Generated 91 tokens in 0.64 seconds (141.75 tokens/sec)2025-02-10 20:48:10,442 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:10,442 - INFO - Round 19 of 20.2025-02-10 20:48:11,084 - INFO - Generated 91 tokens in 0.64 seconds (141.77 tokens/sec)2025-02-10 20:48:11,085 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:11,085 - INFO - Round 20 of 20.2025-02-10 20:48:11,727 - INFO - Generated 91 tokens in 0.64 seconds (141.75 tokens/sec)2025-02-10 20:48:11,727 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:48:11,728 - INFO - Run details saved to gpt2-rates-2025-02-10-20-48-11-727901-torch-compile-fullgraph-py313t.json.torch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedtorch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedreal 0m20.125suser 0m23.218ssys 0m1.382sException ignored in: <_io.BufferedWriter name=41>BrokenPipeError: [Errno 32] Broken pipe2025-02-10 20:48:17,900 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.486 seconds.2025-02-10 20:48:17,907 - INFO - Initialized GPT model in 0.006 seconds.2025-02-10 20:48:18,132 - INFO - Loaded model weights in 0.225 seconds.<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL)has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL. To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.2025-02-10 20:48:18,494 - INFO - Created GPT model in 0.594 seconds.2025-02-10 20:48:18,577 - INFO - Moved model to cuda:3 in 0.082 seconds.2025-02-10 20:48:18,577 - INFO - Loaded model from step 19072, val_loss 3.05197024345397952025-02-10 20:48:18,577 - INFO - Loaded gpt2 on cuda:3 in 1.162 seconds.2025-02-10 20:48:18,581 - INFO - torch.compiled model in 0.004 seconds.2025-02-10 20:48:18,581 - INFO - Round 1 of 20.2025-02-10 20:48:19,397 - INFO - Generated 91 tokens in 0.82 seconds (111.57 tokens/sec)2025-02-10 20:48:19,398 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:19,398 - INFO - Round 2 of 20.2025-02-10 20:48:20,044 - INFO - Generated 91 tokens in 0.65 seconds (140.99 tokens/sec)2025-02-10 20:48:20,044 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:20,044 - INFO - Round 3 of 20.2025-02-10 20:48:20,690 - INFO - Generated 91 tokens in 0.65 seconds (140.95 tokens/sec)2025-02-10 20:48:20,690 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:20,690 - INFO - Round 4 of 20.2025-02-10 20:48:21,335 - INFO - Generated 91 tokens in 0.64 seconds (141.16 tokens/sec)2025-02-10 20:48:21,336 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:21,336 - INFO - Round 5 of 20.2025-02-10 20:48:21,981 - INFO - Generated 91 tokens in 0.64 seconds (141.10 tokens/sec)2025-02-10 20:48:21,981 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:21,981 - INFO - Round 6 of 20.2025-02-10 20:48:22,627 - INFO - Generated 91 tokens in 0.65 seconds (141.03 tokens/sec)2025-02-10 20:48:22,627 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:48:22,627 - INFO - Round 7 of 20.2025-02-10 20:48:23,274 - INFO - Generated 91 tokens in 0.65 seconds (140.67 tokens/sec)2025-02-10 20:48:23,275 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:23,275 - INFO - Round 8 of 20.
2025-02-10 20:48:23,950 - INFO - Generated 91 tokens in 0.68 seconds (134.77 tokens/sec)
2025-02-10 20:48:23,950 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:23,951 - INFO - Round 9 of 20.
2025-02-10 20:48:24,637 - INFO - Generated 91 tokens in 0.69 seconds (132.66 tokens/sec)
2025-02-10 20:48:24,637 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:24,637 - INFO - Round 10 of 20.
2025-02-10 20:48:25,293 - INFO - Generated 91 tokens in 0.66 seconds (138.78 tokens/sec)
2025-02-10 20:48:25,294 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:25,294 - INFO - Round 11 of 20.
2025-02-10 20:48:25,952 - INFO - Generated 91 tokens in 0.66 seconds (138.35 tokens/sec)
2025-02-10 20:48:25,952 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:25,952 - INFO - Round 12 of 20.
2025-02-10 20:48:26,609 - INFO - Generated 91 tokens in 0.66 seconds (138.65 tokens/sec)
2025-02-10 20:48:26,609 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:26,609 - INFO - Round 13 of 20.
2025-02-10 20:48:27,265 - INFO - Generated 91 tokens in 0.66 seconds (138.73 tokens/sec)
2025-02-10 20:48:27,265 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:27,265 - INFO - Round 14 of 20.
2025-02-10 20:48:27,922 - INFO - Generated 91 tokens in 0.66 seconds (138.71 tokens/sec)
2025-02-10 20:48:27,922 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:27,922 - INFO - Round 15 of 20.
2025-02-10 20:48:28,579 - INFO - Generated 91 tokens in 0.66 seconds (138.54 tokens/sec)
2025-02-10 20:48:28,580 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:28,580 - INFO - Round 16 of 20.
2025-02-10 20:48:29,236 - INFO - Generated 91 tokens in 0.66 seconds (138.70 tokens/sec)
2025-02-10 20:48:29,236 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:29,236 - INFO - Round 17 of 20.
2025-02-10 20:48:29,894 - INFO - Generated 91 tokens in 0.66 seconds (138.52 tokens/sec)
2025-02-10 20:48:29,894 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:29,894 - INFO - Round 18 of 20.
2025-02-10 20:48:30,550 - INFO - Generated 91 tokens in 0.66 seconds (138.78 tokens/sec)
2025-02-10 20:48:30,550 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:30,550 - INFO - Round 19 of 20.
2025-02-10 20:48:31,208 - INFO - Generated 91 tokens in 0.66 seconds (138.36 tokens/sec)
2025-02-10 20:48:31,209 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:31,209 - INFO - Round 20 of 20.
2025-02-10 20:48:31,869 - INFO - Generated 91 tokens in 0.66 seconds (137.96 tokens/sec)
2025-02-10 20:48:31,869 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:31,869 - INFO - Run details saved to gpt2-rates-2025-02-10-20-48-31-869451-torch-compile-reduce-overhead-py313t.json.
torch.__init__: Ignoring unsupported torch.compile()
with no GIL unsupported
torch._dynamo.eval_frame: Ignoring unsupported torch.compile()
with no GIL unsupported
torch.__init__: Ignoring unsupported torch.compile()
with no GIL unsupported
torch._dynamo.eval_frame: Ignoring unsupported torch.compile()
with no GIL unsupported
real 0m20.138s
user 0m23.350s
sys 0m1.425s
Exception ignored in: <_io.BufferedWriter name=41>
BrokenPipeError: [Errno 32] Broken pipe
2025-02-10 20:48:38,069 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.481 seconds.
2025-02-10 20:48:38,076 - INFO - Initialized GPT model in 0.006 seconds.
2025-02-10 20:48:38,305 - INFO - Loaded model weights in 0.228 seconds.
<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL)
has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL.
To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.
2025-02-10 20:48:38,672 - INFO - Created GPT model in 0.602 seconds.
2025-02-10 20:48:38,756 - INFO - Moved model to cuda:3 in 0.085 seconds.
2025-02-10 20:48:38,756 - INFO - Loaded model from step 19072, val_loss 3.0519702434539795
2025-02-10 20:48:38,757 - INFO - Loaded gpt2 on cuda:3 in 1.169 seconds.
2025-02-10 20:48:38,759 - INFO - torch.compiled model in 0.002 seconds.
2025-02-10 20:48:38,759 - INFO - Round 1 of 20.
2025-02-10 20:48:39,584 - INFO - Generated 91 tokens in 0.82 seconds (110.43 tokens/sec)
2025-02-10 20:48:39,584 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:39,584 - INFO - Round 2 of 20.
2025-02-10 20:48:40,223 - INFO - Generated 91 tokens in 0.64 seconds (142.66 tokens/sec)
2025-02-10 20:48:40,223 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is
2025-02-10 20:48:40,223 - INFO - Round 3 of 20.
2025-02-10 20:48:40,863 - INFO - Generated 91 tokens in 0.64 seconds (142.18 tokens/sec)
2025-02-10 20:48:40,864 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:40,864 - INFO - Round 4 of 20.
2025-02-10 20:48:41,501 - INFO - Generated 91 tokens in 0.64 seconds (142.76 tokens/sec)
2025-02-10 20:48:41,502 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:41,502 - INFO - Round 5 of 20.
2025-02-10 20:48:42,140 - INFO - Generated 91 tokens in 0.64 seconds (142.55 tokens/sec)
2025-02-10 20:48:42,141 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is
2025-02-10 20:48:42,141 - INFO - Round 6 of 20.
2025-02-10 20:48:42,779 - INFO - Generated 91 tokens in 0.64 seconds (142.60 tokens/sec)
2025-02-10 20:48:42,779 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:42,779 - INFO - Round 7 of 20.
2025-02-10 20:48:43,416 - INFO - Generated 91 tokens in 0.64 seconds (142.91 tokens/sec)
2025-02-10 20:48:43,417 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:43,417 - INFO - Round 8 of 20.
2025-02-10 20:48:44,058 - INFO - Generated 91 tokens in 0.64 seconds (141.92 tokens/sec)
2025-02-10 20:48:44,059 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is
2025-02-10 20:48:44,059 - INFO - Round 9 of 20.
2025-02-10 20:48:44,700 - INFO - Generated 91 tokens in 0.64 seconds (141.97 tokens/sec)
2025-02-10 20:48:44,700 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:44,700 - INFO - Round 10 of 20.
2025-02-10 20:48:45,339 - INFO - Generated 91 tokens in 0.64 seconds (142.54 tokens/sec)
2025-02-10 20:48:45,339 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:45,339 - INFO - Round 11 of 20.
2025-02-10 20:48:45,977 - INFO - Generated 91 tokens in 0.64 seconds (142.78 tokens/sec)
2025-02-10 20:48:45,977 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is
2025-02-10 20:48:45,977 - INFO - Round 12 of 20.
2025-02-10 20:48:46,617 - INFO - Generated 91 tokens in 0.64 seconds (142.24 tokens/sec)
2025-02-10 20:48:46,618 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:46,618 - INFO - Round 13 of 20.
2025-02-10 20:48:47,264 - INFO - Generated 91 tokens in 0.65 seconds (140.87 tokens/sec)
2025-02-10 20:48:47,264 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:47,265 - INFO - Round 14 of 20.
2025-02-10 20:48:47,912 - INFO - Generated 91 tokens in 0.65 seconds (140.53 tokens/sec)
2025-02-10 20:48:47,913 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is
2025-02-10 20:48:47,913 - INFO - Round 15 of 20.
2025-02-10 20:48:48,561 - INFO - Generated 91 tokens in 0.65 seconds (140.36 tokens/sec)
2025-02-10 20:48:48,562 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:48,562 - INFO - Round 16 of 20.
2025-02-10 20:48:49,185 - INFO - Generated 91 tokens in 0.62 seconds (146.20 tokens/sec)
2025-02-10 20:48:49,185 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:49,185 - INFO - Round 17 of 20.
2025-02-10 20:48:49,805 - INFO - Generated 91 tokens in 0.62 seconds (146.85 tokens/sec)
2025-02-10 20:48:49,805 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is
2025-02-10 20:48:49,805 - INFO - Round 18 of 20.
2025-02-10 20:48:50,426 - INFO - Generated 91 tokens in 0.62 seconds (146.54 tokens/sec)
2025-02-10 20:48:50,427 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:50,427 - INFO - Round 19 of 20.
2025-02-10 20:48:51,048 - INFO - Generated 91 tokens in 0.62 seconds (146.67 tokens/sec)
2025-02-10 20:48:51,048 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:51,048 - INFO - Round 20 of 20.
2025-02-10 20:48:51,667 - INFO - Generated 91 tokens in 0.62 seconds (146.98 tokens/sec)
2025-02-10 20:48:51,668 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time.
In other words, this speed is
2025-02-10 20:48:51,668 - INFO - Run details saved to gpt2-rates-2025-02-10-20-48-51-668280-torch-compile-reduce-overhead-fullgraph-py313t.json.
torch.__init__: Ignoring unsupported torch.compile()
with no GIL unsupported
torch._dynamo.eval_frame: Ignoring unsupported torch.compile()
with no GIL unsupported
torch.__init__: Ignoring unsupported torch.compile()
with no GIL unsupported
torch._dynamo.eval_frame: Ignoring unsupported torch.compile()
with no GIL unsupported
real 0m19.806s
user 0m23.067s
sys 0m1.336s
Exception ignored in: <_io.BufferedWriter name=41>
BrokenPipeError: [Errno 32] Broken pipe
2025-02-10 20:48:57,852 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.472 seconds.
2025-02-10 20:48:57,859 - INFO - Initialized GPT model in 0.007 seconds.
2025-02-10 20:48:58,087 - INFO - Loaded model weights in 0.228 seconds.
<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL)
has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL. To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.
2025-02-10 20:48:58,455 - INFO - Created GPT model in 0.603 seconds.
2025-02-10 20:48:58,540 - INFO - Moved model to cuda:3 in 0.085 seconds.
2025-02-10 20:48:58,540 - INFO - Loaded model from step 19072, val_loss 3.0519702434539795
2025-02-10 20:48:58,541 - INFO - Loaded gpt2 on cuda:3 in 1.161 seconds.
2025-02-10 20:48:58,543 - INFO - torch.compiled model in 0.003 seconds.
2025-02-10 20:48:58,543 - INFO - Round 1 of 20.
2025-02-10 20:48:59,366 - INFO - Generated 91 tokens in 0.82 seconds (110.79 tokens/sec)
2025-02-10 20:48:59,366 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:48:59,366 - INFO - Round 2 of 20.
2025-02-10 20:49:00,019 - INFO - Generated 91 tokens in 0.65 seconds (139.51 tokens/sec)
2025-02-10 20:49:00,019 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:00,019 - INFO - Round 3 of 20.
2025-02-10 20:49:00,706 - INFO - Generated 91 tokens in 0.69 seconds (132.47 tokens/sec)
2025-02-10 20:49:00,707 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:00,707 - INFO - Round 4 of 20.
2025-02-10 20:49:01,327 - INFO - Generated 91 tokens in 0.62 seconds (146.66 tokens/sec)
2025-02-10 20:49:01,328 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:01,328 - INFO - Round 5 of 20.
2025-02-10 20:49:01,950 - INFO - Generated 91 tokens in 0.62 seconds (146.40 tokens/sec)
2025-02-10 20:49:01,950 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:01,950 - INFO - Round 6 of 20.
2025-02-10 20:49:02,571 - INFO - Generated 91 tokens in 0.62 seconds (146.62 tokens/sec)
2025-02-10 20:49:02,571 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:02,571 - INFO - Round 7 of 20.
2025-02-10 20:49:03,193 - INFO - Generated 91 tokens in 0.62 seconds (146.54 tokens/sec)
2025-02-10 20:49:03,193 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:03,193 - INFO - Round 8 of 20.
2025-02-10 20:49:03,816 - INFO - Generated 91 tokens in 0.62 seconds (146.13 tokens/sec)
2025-02-10 20:49:03,816 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:03,816 - INFO - Round 9 of 20.
2025-02-10 20:49:04,444 - INFO - Generated 91 tokens in 0.63 seconds (145.02 tokens/sec)
2025-02-10 20:49:04,444 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:04,444 - INFO - Round 10 of 20.
2025-02-10 20:49:05,066 - INFO - Generated 91 tokens in 0.62 seconds (146.47 tokens/sec)
2025-02-10 20:49:05,066 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:05,066 - INFO - Round 11 of 20.
2025-02-10 20:49:05,688 - INFO - Generated 91 tokens in 0.62 seconds (146.45 tokens/sec)
2025-02-10 20:49:05,688 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:05,688 - INFO - Round 12 of 20.
2025-02-10 20:49:06,310 - INFO - Generated 91 tokens in 0.62 seconds (146.37 tokens/sec)
2025-02-10 20:49:06,311 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:06,311 - INFO - Round 13 of 20.
2025-02-10 20:49:06,933 - INFO - Generated 91 tokens in 0.62 seconds (146.27 tokens/sec)
2025-02-10 20:49:06,933 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:06,933 - INFO - Round 14 of 20.
2025-02-10 20:49:07,554 - INFO - Generated 91 tokens in 0.62 seconds (146.56 tokens/sec)
2025-02-10 20:49:07,555 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:07,555 - INFO - Round 15 of 20.
2025-02-10 20:49:08,176 - INFO - Generated 91 tokens in 0.62 seconds (146.45 tokens/sec)
2025-02-10 20:49:08,177 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed. Since the speed of
light in a vacuum is equal to the speed of the electrons in
a solid, the light from this source has a speed of
approximately 1/299,792 m/s. Einstein's theory of relativity
explains this speed by a phenomenon known as the speed at
the end of time. In other words, this speed is
2025-02-10 20:49:08,177 - INFO - Round 16 of 20.
2025-02-10 20:49:08,799 - INFO - Generated 91 tokens in 0.62 seconds (146.39 tokens/sec)
2025-02-10 20:49:08,799 - INFO - Output:
Einstein's Theory of Relativity states that the speed of
light in a vacuum is simply the speed of the electrons in
that vacuum, since light has a speed.
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:08,799 - INFO - Round 17 of 20.2025-02-10 20:49:09,420 - INFO - Generated 91 tokens in 0.62 seconds (146.49 tokens/sec)2025-02-10 20:49:09,421 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:09,421 - INFO - Round 18 of 20.2025-02-10 20:49:10,043 - INFO - Generated 91 tokens in 0.62 seconds (146.24 tokens/sec)2025-02-10 20:49:10,044 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:10,044 - INFO - Round 19 of 20.2025-02-10 20:49:10,665 - INFO - Generated 91 tokens in 0.62 seconds (146.62 tokens/sec)2025-02-10 20:49:10,665 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:10,665 - INFO - Round 20 of 20.2025-02-10 20:49:11,286 - INFO - Generated 91 tokens in 0.62 seconds (146.54 tokens/sec)2025-02-10 20:49:11,286 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:11,287 - INFO - Run details saved to gpt2-rates-2025-02-10-20-49-11-287012-torch-compile-reduce-overhead-fullgraph-py313t.json.torch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedtorch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedreal 0m19.626suser 0m22.831ssys 0m1.406sException ignored in: <_io.BufferedWriter name=41>BrokenPipeError: [Errno 32] Broken pipe2025-02-10 20:49:17,458 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.491 seconds.2025-02-10 20:49:17,465 - INFO - Initialized GPT model in 0.006 seconds.2025-02-10 20:49:17,704 - INFO - Loaded model weights in 0.239 seconds.<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL)has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL. 
To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.2025-02-10 20:49:18,069 - INFO - Created GPT model in 0.611 seconds.2025-02-10 20:49:18,153 - INFO - Moved model to cuda:3 in 0.084 seconds.2025-02-10 20:49:18,153 - INFO - Loaded model from step 19072, val_loss 3.05197024345397952025-02-10 20:49:18,153 - INFO - Loaded gpt2 on cuda:3 in 1.187 seconds.2025-02-10 20:49:18,157 - INFO - torch.compiled model in 0.003 seconds.2025-02-10 20:49:18,157 - INFO - Round 1 of 20.2025-02-10 20:49:18,973 - INFO - Generated 91 tokens in 0.82 seconds (111.60 tokens/sec)2025-02-10 20:49:18,973 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:18,973 - INFO - Round 2 of 20.2025-02-10 20:49:19,621 - INFO - Generated 91 tokens in 0.65 seconds (140.58 tokens/sec)2025-02-10 20:49:19,621 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:49:19,621 - INFO - Round 3 of 20.2025-02-10 20:49:20,266 - INFO - Generated 91 tokens in 0.64 seconds (141.26 tokens/sec)2025-02-10 20:49:20,266 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:20,266 - INFO - Round 4 of 20.2025-02-10 20:49:20,911 - INFO - Generated 91 tokens in 0.64 seconds (141.24 tokens/sec)2025-02-10 20:49:20,911 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:20,911 - INFO - Round 5 of 20.2025-02-10 20:49:21,555 - INFO - Generated 91 tokens in 0.64 seconds (141.46 tokens/sec)2025-02-10 20:49:21,555 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:49:21,555 - INFO - Round 6 of 20.2025-02-10 20:49:22,199 - INFO - Generated 91 tokens in 0.64 seconds (141.27 tokens/sec)2025-02-10 20:49:22,200 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:22,200 - INFO - Round 7 of 20.2025-02-10 20:49:22,846 - INFO - Generated 91 tokens in 0.65 seconds (140.98 tokens/sec)2025-02-10 20:49:22,846 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:22,846 - INFO - Round 8 of 20.2025-02-10 20:49:23,486 - INFO - Generated 91 tokens in 0.64 seconds (142.29 tokens/sec)2025-02-10 20:49:23,487 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:49:23,487 - INFO - Round 9 of 20.2025-02-10 20:49:24,129 - INFO - Generated 91 tokens in 0.64 seconds (141.71 tokens/sec)2025-02-10 20:49:24,130 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:24,130 - INFO - Round 10 of 20.2025-02-10 20:49:24,771 - INFO - Generated 91 tokens in 0.64 seconds (141.83 tokens/sec)2025-02-10 20:49:24,772 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:24,772 - INFO - Round 11 of 20.2025-02-10 20:49:25,413 - INFO - Generated 91 tokens in 0.64 seconds (142.00 tokens/sec)2025-02-10 20:49:25,413 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:49:25,414 - INFO - Round 12 of 20.2025-02-10 20:49:26,056 - INFO - Generated 91 tokens in 0.64 seconds (141.66 tokens/sec)2025-02-10 20:49:26,056 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:26,057 - INFO - Round 13 of 20.2025-02-10 20:49:26,703 - INFO - Generated 91 tokens in 0.65 seconds (140.87 tokens/sec)2025-02-10 20:49:26,703 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:26,703 - INFO - Round 14 of 20.2025-02-10 20:49:27,344 - INFO - Generated 91 tokens in 0.64 seconds (141.96 tokens/sec)2025-02-10 20:49:27,345 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:49:27,345 - INFO - Round 15 of 20.2025-02-10 20:49:27,986 - INFO - Generated 91 tokens in 0.64 seconds (142.08 tokens/sec)2025-02-10 20:49:27,986 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:27,986 - INFO - Round 16 of 20.2025-02-10 20:49:28,628 - INFO - Generated 91 tokens in 0.64 seconds (141.93 tokens/sec)2025-02-10 20:49:28,628 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:28,628 - INFO - Round 17 of 20.2025-02-10 20:49:29,269 - INFO - Generated 91 tokens in 0.64 seconds (142.10 tokens/sec)2025-02-10 20:49:29,269 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:49:29,269 - INFO - Round 18 of 20.2025-02-10 20:49:29,913 - INFO - Generated 91 tokens in 0.64 seconds (141.43 tokens/sec)2025-02-10 20:49:29,913 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:29,913 - INFO - Round 19 of 20.2025-02-10 20:49:30,557 - INFO - Generated 91 tokens in 0.64 seconds (141.46 tokens/sec)2025-02-10 20:49:30,557 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:30,557 - INFO - Round 20 of 20.2025-02-10 20:49:31,199 - INFO - Generated 91 tokens in 0.64 seconds (141.84 tokens/sec)2025-02-10 20:49:31,199 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. 
In other words, this speed is2025-02-10 20:49:31,200 - INFO - Run details saved to gpt2-rates-2025-02-10-20-49-31-199651-torch-compile-max-autotune-py313t.json.torch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedtorch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedreal 0m19.900suser 0m23.112ssys 0m1.409sException ignored in: <_io.BufferedWriter name=41>BrokenPipeError: [Errno 32] Broken pipe2025-02-10 20:49:37,383 - INFO - Loaded /mnt/raid1/trent/src/parallelopedia/data/model_19072.pt checkpoint in 0.482 seconds.2025-02-10 20:49:37,390 - INFO - Initialized GPT model in 0.006 seconds.2025-02-10 20:49:37,613 - INFO - Loaded model weights in 0.223 seconds.<frozen importlib._bootstrap>:488: RuntimeWarning: The global interpreter lock (GIL)has been enabled to load module 'triton._C.libtriton', which has not declared that it can run safely without the GIL. To override this behavior and keep the GIL disabled (at your own risk), run with PYTHON_GIL=0 or -Xgil=0.2025-02-10 20:49:37,976 - INFO - Created GPT model in 0.592 seconds.2025-02-10 20:49:38,058 - INFO - Moved model to cuda:3 in 0.082 seconds.2025-02-10 20:49:38,058 - INFO - Loaded model from step 19072, val_loss 3.05197024345397952025-02-10 20:49:38,058 - INFO - Loaded gpt2 on cuda:3 in 1.157 seconds.2025-02-10 20:49:38,061 - INFO - torch.compiled model in 0.002 seconds.2025-02-10 20:49:38,061 - INFO - Round 1 of 20.2025-02-10 20:49:38,896 - INFO - Generated 91 tokens in 0.83 seconds (109.09 tokens/sec)2025-02-10 20:49:38,896 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:38,896 - INFO - Round 2 of 20.2025-02-10 20:49:39,552 - INFO - Generated 91 tokens in 0.66 seconds (138.68 tokens/sec)2025-02-10 20:49:39,553 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:39,553 - INFO - Round 3 of 20.2025-02-10 20:49:40,199 - INFO - Generated 91 tokens in 0.65 seconds (140.90 tokens/sec)2025-02-10 20:49:40,199 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:40,199 - INFO - Round 4 of 20.2025-02-10 20:49:40,851 - INFO - Generated 91 tokens in 0.65 seconds (139.61 tokens/sec)2025-02-10 20:49:40,852 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:40,852 - INFO - Round 5 of 20.2025-02-10 20:49:41,496 - INFO - Generated 91 tokens in 0.64 seconds (141.24 tokens/sec)2025-02-10 20:49:41,497 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:41,497 - INFO - Round 6 of 20.2025-02-10 20:49:42,143 - INFO - Generated 91 tokens in 0.65 seconds (140.79 tokens/sec)2025-02-10 20:49:42,144 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:42,144 - INFO - Round 7 of 20.2025-02-10 20:49:42,794 - INFO - Generated 91 tokens in 0.65 seconds (139.98 tokens/sec)2025-02-10 20:49:42,794 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:42,795 - INFO - Round 8 of 20.2025-02-10 20:49:43,442 - INFO - Generated 91 tokens in 0.65 seconds (140.59 tokens/sec)2025-02-10 20:49:43,442 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:43,442 - INFO - Round 9 of 20.2025-02-10 20:49:44,088 - INFO - Generated 91 tokens in 0.65 seconds (141.07 tokens/sec)2025-02-10 20:49:44,088 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:44,088 - INFO - Round 10 of 20.2025-02-10 20:49:44,738 - INFO - Generated 91 tokens in 0.65 seconds (140.10 tokens/sec)2025-02-10 20:49:44,738 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:44,738 - INFO - Round 11 of 20.2025-02-10 20:49:45,382 - INFO - Generated 91 tokens in 0.64 seconds (141.34 tokens/sec)2025-02-10 20:49:45,383 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:45,383 - INFO - Round 12 of 20.2025-02-10 20:49:46,026 - INFO - Generated 91 tokens in 0.64 seconds (141.47 tokens/sec)2025-02-10 20:49:46,027 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:46,027 - INFO - Round 13 of 20.2025-02-10 20:49:46,695 - INFO - Generated 91 tokens in 0.67 seconds (136.23 tokens/sec)2025-02-10 20:49:46,695 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:46,695 - INFO - Round 14 of 20.2025-02-10 20:49:47,339 - INFO - Generated 91 tokens in 0.64 seconds (141.40 tokens/sec)2025-02-10 20:49:47,340 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:47,340 - INFO - Round 15 of 20.2025-02-10 20:49:47,988 - INFO - Generated 91 tokens in 0.65 seconds (140.35 tokens/sec)2025-02-10 20:49:47,989 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:47,989 - INFO - Round 16 of 20.2025-02-10 20:49:48,649 - INFO - Generated 91 tokens in 0.66 seconds (137.84 tokens/sec)2025-02-10 20:49:48,649 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:48,649 - INFO - Round 17 of 20.2025-02-10 20:49:49,295 - INFO - Generated 91 tokens in 0.64 seconds (141.12 tokens/sec)2025-02-10 20:49:49,295 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:49,295 - INFO - Round 18 of 20.2025-02-10 20:49:49,939 - INFO - Generated 91 tokens in 0.64 seconds (141.44 tokens/sec)2025-02-10 20:49:49,939 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:49,939 - INFO - Round 19 of 20.2025-02-10 20:49:50,582 - INFO - Generated 91 tokens in 0.64 seconds (141.49 tokens/sec)2025-02-10 20:49:50,583 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. 
Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:50,583 - INFO - Round 20 of 20.2025-02-10 20:49:51,232 - INFO - Generated 91 tokens in 0.65 seconds (140.31 tokens/sec)2025-02-10 20:49:51,232 - INFO - Output:Einstein's Theory of Relativity states that the speed oflight in a vacuum is simply the speed of the electrons inthat vacuum, since light has a speed. Since the speed oflight in a vacuum is equal to the speed of the electrons ina solid, the light from this source has a speed ofapproximately 1/299,792 m/s. Einstein's theory of relativityexplains this speed by a phenomenon known as the speed atthe end of time. In other words, this speed is2025-02-10 20:49:51,232 - INFO - Run details saved to gpt2-rates-2025-02-10-20-49-51-232526-torch-compile-max-autotune-fullgraph-py313t.json.torch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedtorch.__init__: Ignoring unsupported torch.compile()with no GIL unsupportedtorch._dynamo.eval_frame: Ignoring unsupported torch.compile()with no GIL unsupportedreal 0m20.061suser 0m23.272ssys 0m1.387sException ignored in: <_io.BufferedWriter name=41>BrokenPipeError: [Errno 32] Broken pipe
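The run files capture the per-round rates directly, but if all you have is console output like the log above, the same numbers can be recovered from the "Generated … tokens/sec" lines. The following is a minimal sketch, assuming the exact "Generated N tokens in S seconds (R tokens/sec)" wording shown in these logs; `extract_rates` is a hypothetical helper, not part of the benchmark script.

```python
import re

# Matches log lines such as:
# "2025-02-10 20:49:05,688 - INFO - Generated 91 tokens in 0.62 seconds (146.45 tokens/sec)"
RATE_RE = re.compile(
    r"Generated (?P<tokens>\d+) tokens in (?P<secs>[\d.]+) seconds "
    r"\((?P<rate>[\d.]+) tokens/sec\)"
)


def extract_rates(log_text: str) -> list[float]:
    """Return the tokens/sec value from every 'Generated ...' line, in log order."""
    return [float(m.group("rate")) for m in RATE_RE.finditer(log_text)]


if __name__ == "__main__":
    sample = "\n".join([
        "2025-02-10 20:49:38,061 - INFO - Round 1 of 20.",
        "2025-02-10 20:49:38,896 - INFO - Generated 91 tokens in 0.83 seconds (109.09 tokens/sec)",
        "2025-02-10 20:49:38,896 - INFO - Round 2 of 20.",
        "2025-02-10 20:49:39,552 - INFO - Generated 91 tokens in 0.66 seconds (138.68 tokens/sec)",
    ])
    print(extract_rates(sample))  # [109.09, 138.68]
```

Because the regex only keys on the "Generated" lines, interleaved model output, warnings and `time` summaries are ignored without any extra filtering.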
After each run, a JSON file capturing details about the run is saved. All of them can be found in this json directory.
{
  "rates": [
    103.45683188097036,
    133.69021647689885,
    134.64711111769682,
    134.99572668445947,
    135.07252279473053,
    134.93995013027606,
    134.8097791019913,
    134.97874539755563,
    134.96994164914173,
    134.2781961646748,
    133.36089184927351,
    134.91936892711604,
    134.99927221213662,
    134.99574871578727,
    134.84820822748674,
    134.90753647890656,
    134.84542991060954,
    134.76236728440452,
    134.939303436755,
    134.95605543169106
  ],
  "model_config": {
    "block_size": 1024,
    "vocab_size": 50304,
    "n_layer": 12,
    "n_head": 12,
    "n_embd": 768
  },
  "args": {
    "log_level": "INFO",
    "model": "gpt2",
    "device": "cuda:3",
    "max_length": 100,
    "top_k": 50,
    "seed": 42,
    "prompt": "Einstein's Theory of Relativity states that",
    "torch_compile": false,
    "torch_jit": false,
    "torch_compile_fullgraph": false,
    "torch_compile_reduce_overhead": false,
    "torch_compile_max_autotune": false,
    "generate_slim": false,
    "rounds": 20,
    "wrap": 60,
    "note": ""
  },
  "start_timestamp": "2025-02-09T18:19:14.156224",
  "end_timestamp": "2025-02-09T18:19:29.116383",
  "elapsed": "14.960",
  "device_name": "Tesla V100-DGXS-32GB",
  "conda_env_name": "py313t",
  "is_gil_enabled": false,
  "note": ""
}
All tests were run for 20 rounds. The rates key contains an array of floats, one per round, representing the tokens/sec generation rate achieved by the call to generate() (or generate_slim(), which we'll discuss shortly) after optionally compiling the model with the requested parameters.
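Because the first round is a warmup that runs noticeably slower (about 103 tokens/sec versus roughly 135 for the rest in the sample above), it helps to drop it before computing summary statistics. The following is a minimal sketch, not the script used to produce the plots in this post: the function names are hypothetical, and it assumes the files follow the gpt2-rates-*.json naming seen in the log output. It reports the minimum, quartiles, median, and maximum, which are the same statistics a box plot draws.

```python
import json
import statistics
from pathlib import Path


def summarize_rates(run: dict, skip_warmup: bool = True) -> dict:
    """Summarize the per-round tokens/sec rates from one run's JSON."""
    rates = run["rates"]
    # The first round pays one-off warmup costs, so optionally drop it.
    if skip_warmup:
        rates = rates[1:]
    # quantiles(n=4) returns three cut points; keep the outer two as Q1/Q3.
    q1, _, q3 = statistics.quantiles(rates, n=4)
    return {
        "conda_env_name": run.get("conda_env_name"),
        "rounds": len(rates),
        "min": min(rates),
        "q1": q1,
        "median": statistics.median(rates),
        "q3": q3,
        "max": max(rates),
    }


def summarize_directory(directory: str) -> list[dict]:
    """Summarize every gpt2-rates-*.json file found in a directory."""
    summaries = []
    for path in sorted(Path(directory).glob("gpt2-rates-*.json")):
        with path.open() as f:
            summary = summarize_rates(json.load(f))
        summary["file"] = path.name
        summaries.append(summary)
    return summaries
```

Assuming the files live in a local directory named json, summarize_directory("json") returns one summary dict per run, which makes it easy to compare medians across the different torch.compile() permutations.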
Free-Threaded Python (py313t)
I ended up doing multiple runs because… well, as you're about to see from the data visualizations below… it was a pretty noisy test. Loads of variance, and generally everything was all over the shop. The total run times for each permutation were all about the same, around 14 seconds or so, and there was definitely no clear winner as to whether torch.compile(), or any particular permutation of it, delivered a repeatable speedup. Interesting.
Line plots and box plots follow. The box plots omit the first "warmup" round, which was always slower and skewed the data unnecessarily.
Normal Python (py313)
Normal Python (not just free-threaded Python with the GIL enabled, but full-blown old-school normal Python with no knowledge of GIL removal—i.e. our py313 environment) was pretty similar:
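The distinction matters because a free-threaded build can still end up running with the GIL enabled, for example when PYTHON_GIL=1 (or -X gil=1) is set, or when an extension module that doesn't declare free-threading support is imported. As a minimal, standard-library-only sketch (the function name is hypothetical; the is_gil_enabled field in the JSON above suggests a similar runtime check, but the post's actual code isn't shown here), the two cases can be told apart like this:

```python
import sys
import sysconfig


def gil_status() -> dict[str, bool]:
    """Report whether this is a free-threaded build, and whether the GIL
    is actually enabled at runtime."""
    # Py_GIL_DISABLED is 1 on free-threaded builds (e.g. python3.13t);
    # on normal builds it is 0, or None on older versions.
    free_threaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() was added in 3.13; older interpreters always
    # have the GIL, so fall back to True when it doesn't exist.
    is_gil_enabled = bool(getattr(sys, "_is_gil_enabled", lambda: True)())
    return {
        "free_threaded_build": free_threaded_build,
        "is_gil_enabled": is_gil_enabled,
    }
```

On the py313 environment this should report a normal build with the GIL enabled; on py313t it should report a free-threaded build, and the runtime flag shows whether the GIL really stayed off.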
[Figure: Generation Performance (py313), line plot and box plot]
Explicit @torch.compile Decorator
Here's where it gets interesting. In a last-ditch effort to get any sort of speedup out of torch.compile(), I introduced a slimmer generate() routine, aptly named generate_slim(), stripped of any superfluous code that would otherwise have impeded TorchDynamo's ability to optimize the graph. That function looked like this:
```python
import torch
import torch.nn.functional as F


class GPT:
    ...

    # @torch.compile
    def generate_slim(
        self,
        text_tokens: torch.Tensor,
        max_length: int = 1024,
        top_k: int = 50,
        seed: int | None = None,
    ) -> torch.Tensor:
        """
        Generate text from the model.

        This version differs from `generate()` in that it does not use
        any Python code that causes a torch graph break.

        Args:
            text_tokens (torch.Tensor): Supplies the prompt token IDs.
            max_length (int): Supplies the maximum total length,
                including prompt.
            top_k (int): Supplies the number of tokens to consider at
                each generation step.
            seed (int): Ignored!

        Returns:
            torch.Tensor: The generated token IDs (including the initial
                prompt), with shape (1, sequence_length).
        """
        # Initialize aliases.
        device = self.device
        stop_token = self.stop_token

        # Create the tensor for capturing predicted tokens.
        x = torch.tensor(
            text_tokens,
            dtype=torch.long,
            device=device,
        ).unsqueeze(0)

        # Create a random generator for reproducibility.
        # sample_rng = torch.Generator(device=device)
        # if seed is None:
        #     seed = self.manual_seed
        # sample_rng.manual_seed(seed)

        # Generate tokens up to our max length.
        for _ in range(max_length):
            with torch.no_grad():
                # Forward pass, ignoring the returned loss.
                (logits, _) = self(x)

            # Take the logits at the last time-step (shape:
            # (1, vocab_size)).
            logits = logits[:, -1, :]

            # Convert to probabilities.
            probs = F.softmax(logits, dim=-1)

            # Top-k sampling.
            topk_probs, topk_indices = torch.topk(
                probs,
                k=top_k,
                dim=-1,
            )

            # Sample the next token.
            next_idx = torch.multinomial(
                topk_probs,
                num_samples=1,
                # generator=sample_rng,
            )
            next_token = torch.gather(topk_indices, -1, next_idx)

            # If the next token is the stop token, we're done.
            # next_token_item = next_token.item()
            # if next_token_item == stop_token:
            #     break

            # Append token to current sequence.
            x = torch.cat((x, next_token), dim=1)

        return x
```
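The sampling step at the heart of generate_slim() (softmax, then topk, multinomial, and gather) can be mirrored in plain Python for a single row of logits. This is an illustrative sketch only, not a drop-in replacement: plain lists stand in for tensors, random.Random stands in for the removed torch.Generator, and the function name is hypothetical.

```python
import math
import random


def sample_top_k(logits: list[float], top_k: int, rng: random.Random) -> int:
    """Return a vocabulary index sampled from the top_k most likely tokens."""
    # Softmax, shifted by the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Like torch.topk(): indices of the top_k largest probabilities,
    # in descending order of probability.
    topk_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    topk_probs = [probs[i] for i in topk_indices]

    # Like torch.multinomial(topk_probs, num_samples=1): pick a position
    # within the top-k list, weighted by probability. Neither multinomial
    # nor random.choices requires the weights to sum to 1.
    next_idx = rng.choices(range(len(topk_probs)), weights=topk_probs, k=1)[0]

    # Like torch.gather(topk_indices, -1, next_idx): map that position
    # back to an index into the full vocabulary.
    return topk_indices[next_idx]
```

The key detail is the final gather: multinomial samples a position within the top-k slice, not a vocabulary ID, so the position has to be translated back through topk_indices.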
Note that we had to make a number of sizable modifications. No more random number generator: that was causing graph breaks. No more explicit check for the stop token, either: that check needs .item() to pull a Python scalar out of a tensor on every step, and branching on that value makes the dynamic optimizer's job much harder at runtime without extra tracing overhead for tracking scalars. So we now generate tokens up to the maximum specified, ignorant of any stop tokens, and return the resulting token tensor.
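One consequence of dropping the stop-token check is that callers always get the full max_length worth of tokens back. A minimal sketch of one way to handle that, assuming the trimming happens once on the host after generation (for example on x[0].tolist()), outside the compiled region; the helper name is hypothetical and not part of the post's code:

```python
def truncate_at_stop(tokens: list[int], stop_token: int, prompt_length: int) -> list[int]:
    """Trim generated tokens at the first stop token after the prompt.

    The stop token itself is dropped; the prompt is always kept intact.
    """
    try:
        # Only search the generated portion, never the prompt.
        end = tokens.index(stop_token, prompt_length)
    except ValueError:
        # No stop token was generated, so keep everything.
        return tokens
    return tokens[:end]
```

This trades a little wasted compute (tokens generated after the stop token) for a generation loop with no data-dependent control flow, which is exactly the trade generate_slim() makes.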
I wanted to do two runs here, one where we call everything as normal with all the different torch.compile() invocations we’d used in prior runs, and then a second one where I explicitly mark the generate_slim() routine with a @torch.compile decorator.
The latter absolutely does not work in free-threaded Python: it segfaults after about ten seconds or so, so I couldn't test it.
However, on the normal Python version… we finally saw some interesting results.
First, let's look at our baseline: normal generate_slim() with no @torch.compile decorator (the bash scripts verify that the decorator is commented or uncommented as necessary):
[Figure: Generation Performance - generate_slim() - No @torch.compile Decorator (py313), line plot and box plot]
Well, we finally see one configuration break out from the pack: apparently torch.compile(model, fullgraph=True) yielded the best generation rate we've seen yet, hovering around 160 tokens/sec. Note that the total run times are all still pretty similar, hovering around that 14-15s mark.
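torch.compile() takes its options (fullgraph, mode, and so on) as keyword-only arguments. As a hedged sketch of how the flags recorded in each run's args block (torch_compile_fullgraph, torch_compile_reduce_overhead, torch_compile_max_autotune) might be translated into those keywords, here is a hypothetical helper; the post's actual script isn't shown, so treat the mapping as an assumption:

```python
def compile_kwargs(args: dict) -> dict:
    """Map a run's torch_compile_* flags to torch.compile() keyword args."""
    kwargs = {}
    if args.get("torch_compile_fullgraph"):
        # Require a single graph: any graph break raises an error
        # instead of silently falling back to eager execution.
        kwargs["fullgraph"] = True

    reduce_overhead = args.get("torch_compile_reduce_overhead")
    max_autotune = args.get("torch_compile_max_autotune")
    if reduce_overhead and max_autotune:
        # torch.compile() accepts a single mode string.
        raise ValueError("choose at most one torch.compile() mode")
    if reduce_overhead:
        kwargs["mode"] = "reduce-overhead"
    elif max_autotune:
        kwargs["mode"] = "max-autotune"
    return kwargs
```

Usage would then look like model = torch.compile(model, **compile_kwargs(run["args"])). Rejecting conflicting modes up front is a design choice of this sketch, so a misconfigured run fails loudly rather than silently picking one.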
Now, let’s uncomment the @torch.compile decorator above def generate_slim() and do another full run:
[Figure: Generation Performance - generate_slim() - With @torch.compile Decorator (py313), line plot and box plot]
Oh man. That first compilation took forever. But once compiled, our tokens/sec generation rate shoots up significantly, to the 250+ range instead of the 150+ range! That comes at a crazy up-front cost, though: as you can see, all of the run times (the values in parentheses in the labels of the first plot, and on the x-axis of the second) were in excess of four minutes.
As we're only doing twenty rounds of generating ~90-100 tokens, that startup cost is brutal. However, if we were doing model training or launching a long-running inference service like our HTTP server, the startup cost would be quickly amortized away as we benefit from about a 75% speedup.
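To put rough numbers on that amortization, here is a small worked sketch. It makes a simplifying assumption: the four-plus minutes are treated as pure compilation overhead, even though those run times also include the generation itself. The inputs are the approximate 150 and 250 tokens/sec rates from the plots above, and the function name is hypothetical.

```python
def break_even_tokens(compile_seconds: float, eager_rate: float, compiled_rate: float) -> float:
    """Tokens that must be generated before compilation pays for itself."""
    if compiled_rate <= eager_rate:
        raise ValueError("compiled_rate must exceed eager_rate")
    # Seconds saved on every generated token by the compiled version.
    saved_per_token = 1.0 / eager_rate - 1.0 / compiled_rate
    return compile_seconds / saved_per_token
```

With those figures, break_even_tokens(240, 150, 250) comes out to 90,000 tokens, i.e. roughly 900 generations of ~100 tokens each, after which the compiled version is ahead. That's nothing for a long-running server, but far more than a twenty-round benchmark.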
So, interesting stuff, kind of. To be fair… PyTorch clearly indicates that torch.compile() isn't supported on free-threaded Python (hence our hacks earlier), so this should be an interesting area of development in future releases, especially now that we can see how much benefit there is to doing multi-threaded run-time inference in a single Python process.
Conclusion
In this article, we’ve explored the world of free-threaded Python and PyTorch, and demonstrated that you can now do parallel inference on PyTorch models, and that it all plays very nicely together when wrapped up with an asyncio-based HTTP server.
Hopefully this encourages more folks to experiment with free-threaded Python, or perhaps port their existing Python packages to play nicely when installed in a free-threaded Python environment. I personally can’t wait until free-threaded Python is the default! Although that’s probably at least five or so years out at this point.
Footnotes
I used 0.23.3, as that was the latest version available at the time; however, 0.23.4 has since been released, so you could try that too.↩︎
And I’m sure I used the existing Python stdlib http.server code at the time as the basis; ain’t nobody got time to be writing new web servers from scratch.↩︎
Unfortunately, it doesn't appear to work on Windows as-is; using the exact same code, only one thread can be seen running when the server is loaded. It's not doing a round-robin between all threads, like you'd expect to see with the GIL enabled; there's just a single thread attempting to service all incoming requests, with all other threads sitting idle. I don't know if it's because of something quirky with regard to additional, non-main-thread threads not getting their own event loop (hopefully easy to fix), or something more insidious related to how we're misusing I/O completion ports behind the scenes in IocpProactor() now that we have free-threading (much harder to fix). I haven't had time to investigate in more detail.↩︎