
~ tinyllama ~

Model classes and pre-training utilities for a tiny version of Llama in PyTorch.

Installation 🚀

pip install tinyllama

Pre-training a model 🏋️‍♀️

Initializing a tokenizer

With a simple character-level tokenizer:

from tinyllama.tokenizers import CharacterTokenizer
tokenizer = CharacterTokenizer()
# '|' is the default eos_token
tokenizer.add_eos_tokens()

To turn a corpus into tokens:

tokens = tokenizer.tokenize(corpus)
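
For example, a minimal round trip with the calls above (whether tokenize returns a list or a tensor is not specified here):

# a tiny corpus; '|' is the eos_token added above
corpus = "hello llama|tiny model|"
tokens = tokenizer.tokenize(corpus)
# with a character-level tokenizer, expect roughly one token per character
print(len(tokens))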

Initializing a Llama model

from tinyllama import Llama
model = Llama(context_window=500, emb_dim=10, n_heads=2, n_blocks=2)

Multi-Query attention

model = Llama(context_window=500, emb_dim=10, n_heads=2, n_blocks=2, gq_ratio=2)

The gq_ratio parameter is the ratio $\frac{\text{number of attention heads}}{\text{number of key/value heads}}$; it is set to 1 by default (standard multi-head attention).

The configuration above builds a Llama model with twice as many attention heads as key/value heads.
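
Concretely, under this reading (the exact relationship is an assumption about the library's internals):

n_heads = 2
gq_ratio = 2
# assumed: gq_ratio evenly divides the number of heads
assert n_heads % gq_ratio == 0
n_kv_heads = n_heads // gq_ratio  # 1 shared key/value head, i.e. multi-query attention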

Launching a pre-training job

from tinyllama import TrainConfig, Trainer
train_config = TrainConfig(batch_size=32, epochs=50, log_interval=15)
trainer = Trainer(train_config)
trainer.run(model, tokens)
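
Putting the pieces together, an end-to-end pre-training sketch using only the calls shown so far (the corpus string is a placeholder):

from tinyllama import Llama, TrainConfig, Trainer
from tinyllama.tokenizers import CharacterTokenizer

corpus = "your training text here|"  # placeholder corpus

tokenizer = CharacterTokenizer()
tokenizer.add_eos_tokens()
tokens = tokenizer.tokenize(corpus)

model = Llama(context_window=500, emb_dim=10, n_heads=2, n_blocks=2)
train_config = TrainConfig(batch_size=32, epochs=50, log_interval=15)
trainer = Trainer(train_config)
trainer.run(model, tokens)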

Diagnosis 😷

A diagnosis class runs a training job on a copy of the model and returns training information that can be useful to the user.

Diagnosing the learning rate

Returns a plot of the loss for each candidate learning rate; the start and end arguments are exponents on a logarithmic scale (base 10).

from tinyllama.diagnosis import LrDiagnose
lr_diagnose = LrDiagnose(start=-5, end=0, n_lrs=50)
lr_diagnose.run(model, tokens, train_config)
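
With start=-5 and end=0, the candidate learning rates presumably span $10^{-5}$ to $10^{0}$; spacing them evenly in log space looks like this (an assumption about how the library samples, sketched with NumPy):

import numpy as np

# 50 learning rates spaced evenly in log space between 10**-5 and 10**0
lrs = np.logspace(-5, 0, num=50)
print(lrs[0], lrs[-1])  # 1e-05 1.0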

Diagnosing the gradients

Returns a histogram representing the distribution of the gradients; this diagnosis does not run an additional training job.

from tinyllama.diagnosis import GradDiagnose
grad_diagnose = GradDiagnose(num_params_to_track=1500)
grad_diagnose.run(model)
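
For intuition, roughly what such a histogram is built from can be reproduced in plain PyTorch (a sketch, not the library's implementation; it assumes a backward pass has already populated the gradients):

import torch

# collect every parameter gradient into one flat tensor
grads = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])
counts, bin_edges = torch.histogram(grads, bins=50)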

Diagnosing the activation layers (SwiGLU layers)

Returns a histogram representing the distribution of the SwiGLU layer activations.

from tinyllama.diagnosis import SwigluDiagnose
swiglu_diagnose = SwigluDiagnose(num_embeddings_for_histogram=50, track_direction="forward")
swiglu_diagnose.run(model, tokens, train_config)
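
track_direction="forward" suggests the histogram is built from forward-pass activations; collecting those by hand would look roughly like this (a sketch; the SwiGLU module naming inside the model is an assumption):

activations = []

def save_activation(module, inputs, output):
    # keep a flattened copy of this layer's forward output
    activations.append(output.detach().flatten())

# attach a hook to every submodule whose class name mentions SwiGLU (assumed naming)
for module in model.modules():
    if "swiglu" in type(module).__name__.lower():
        module.register_forward_hook(save_activation)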

Diagnosing the gradients/data ratios

Returns a plot representing the gradient/data ratio at each training step.

from tinyllama.diagnosis import GdrDiagnose
gdr_diagnose = GdrDiagnose(num_params_to_track=5, num_iters=150)
gdr_diagnose.run(model, tokens, train_config)
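
One plausible definition of this ratio (an assumption about what GdrDiagnose plots; some practitioners also fold the learning rate into the numerator):

import torch

with torch.no_grad():
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        # spread of the gradient relative to the spread of the parameter values
        print(name, (p.grad.std() / p.data.std()).item())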

Hyperparameter tuning with GPTune ⚙️

GPTune facilitates hyperparameter tuning by using Gaussian Processes to guide the search over hyperparameter values.

from tinyllama.gptuner import GPTuneConfig, GPTune
gptune_config = GPTuneConfig(
    num_training_samples=100,
    hyperparams_to_tune=["epochs", "n_heads"],
    l_bounds=[10, 2],
    u_bounds=[50, 5],
    num_evaluations=500,
)
gptune = GPTune(gptune_config)
gptune.run(model, tokens, train_config)
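
The bounds appear to pair up positionally with hyperparams_to_tune, so the config above would search epochs in [10, 50] and n_heads in [2, 5]:

# assumed positional pairing of names and bounds
hyperparams = ["epochs", "n_heads"]
l_bounds, u_bounds = [10, 2], [50, 5]
search_space = {name: (lo, hi) for name, lo, hi in zip(hyperparams, l_bounds, u_bounds)}
print(search_space)  # {'epochs': (10, 50), 'n_heads': (2, 5)}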

Generating ✍️

Generates a response to a prompt.

from tinyllama import generate
# kv_cache is set to True by default.
generate(model, prompt, max_tokens=900, kv_cache=True)
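
The KV cache stores each block's keys and values across decoding steps so earlier tokens are not recomputed; disabling it is mostly useful for debugging. A quick timing comparison sketch (the plain-string prompt is an assumption):

import time

prompt = "Once upon a time"  # hypothetical prompt
for use_cache in (True, False):
    start = time.perf_counter()
    generate(model, prompt, max_tokens=100, kv_cache=use_cache)
    print(f"kv_cache={use_cache}: {time.perf_counter() - start:.2f}s")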

Parsing 📜

Parses a single file or multiple files.

# ".txt" files
from tinyllama.readers import get_text
corpus = get_text("./txt_path")

# ".pdf" files
from tinyllama.readers import get_pdf_text
corpus = get_pdf_text("./pdf_path")

To parse multiple files:

# ".txt" files
from tinyllama.readers import get_text
corpus = ''.join(get_text(txt_path) for txt_path in txt_paths)

# ".pdf" files
from tinyllama.readers import get_pdf_text
corpus = ''.join(get_pdf_text(pdf_path) for pdf_path in pdf_paths)
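
To build the path lists themselves, standard-library globbing is enough (a sketch with pathlib; the directory name is a placeholder):

from pathlib import Path
from tinyllama.readers import get_pdf_text

# collect every PDF under a placeholder directory
pdf_paths = sorted(Path("./docs").glob("*.pdf"))
corpus = ''.join(get_pdf_text(str(p)) for p in pdf_paths)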
