Grok-1 (PyTorch Version)
This repository contains the model and weights of the PyTorch version of the Grok-1 open-weights model. You can find complete example code for the PyTorch-version Grok-1 in the ColossalAI GitHub Repository. We also apply parallelism techniques from the ColossalAI framework (currently Tensor Parallelism) to accelerate inference.
You can find the original weights released by xAI on Hugging Face and the original model in the Grok open release GitHub Repository.
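Tensor Parallelism splits each weight matrix across GPUs so that every device computes only part of a layer. The snippet below is a minimal single-process sketch of the idea, not ColossalAI's API: it slices one hypothetical `nn.Linear` layer into `tp_size` shards on the CPU and checks that the sharded computation reproduces the full layer. Column parallelism concatenates partial outputs (an all-gather on real hardware); row parallelism sums partial outputs (an all-reduce).

```python
import torch

torch.manual_seed(0)
tp_size = 4                                   # hypothetical number of GPUs (shards)
x = torch.randn(2, 16)                        # a batch of 2 hidden vectors
linear = torch.nn.Linear(16, 32, bias=False)  # hypothetical layer to shard
y_full = linear(x)                            # reference: unsharded result

# Column parallelism: each shard owns a slice of the output features.
col_shards = torch.chunk(linear.weight, tp_size, dim=0)    # each (8, 16)
partial_cols = [x @ w.T for w in col_shards]               # each (2, 8), independent
y_col = torch.cat(partial_cols, dim=-1)                    # stands in for all-gather

# Row parallelism: each shard owns a slice of the input features.
x_shards = torch.chunk(x, tp_size, dim=-1)                 # each (2, 4)
row_shards = torch.chunk(linear.weight, tp_size, dim=1)    # each (32, 4)
y_row = sum(xs @ ws.T for xs, ws in zip(x_shards, row_shards))  # stands in for all-reduce

print(torch.allclose(y_col, y_full, atol=1e-5), torch.allclose(y_row, y_full, atol=1e-5))
```

On real hardware each shard lives on a different GPU, so per-device memory and compute shrink roughly by a factor of `tp_size`, at the cost of the communication step.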
Conversion
We translated the original model from JAX to PyTorch and converted the weights by mapping the tensor files to parameter keys, de-quantizing the tensors with their corresponding packed scales, and saving the result to checkpoint files with PyTorch APIs. This PyTorch-version model uses BF16 weights.
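A minimal sketch of the de-quantize-and-save step, assuming each quantized tensor is stored as int8 values plus a scale tensor that broadcasts against them. The tensor shapes and the parameter key below are hypothetical, and an in-memory buffer stands in for a checkpoint file; this is not the actual conversion script.

```python
import io

import torch


def dequantize(weight_int8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Multiply int8 values by their broadcastable scales and return BF16."""
    return (weight_int8.to(torch.float32) * scales.to(torch.float32)).to(torch.bfloat16)


torch.manual_seed(0)
# Hypothetical quantized tensor: a (4, 8) int8 weight with one scale per column
weight_int8 = torch.randint(-127, 128, (4, 8)).to(torch.int8)
scales = torch.rand(1, 8).to(torch.bfloat16)

# Hypothetical parameter key produced by the tensor-file-to-parameter mapping
state_dict = {"model.layers.0.attn.q_proj.weight": dequantize(weight_int8, scales)}

buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(state_dict, buffer)
buffer.seek(0)
restored = torch.load(buffer)
print(restored["model.layers.0.attn.q_proj.weight"].dtype)
```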
Tokenizer
Grok-1 has a vocabulary size of 131072. For now, the PyTorch-version model should be used with the original tokenizer (tokenizer.model in the Grok-1 GitHub Repository), which you can download from the official grok-1 repository:
wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model
Usage
import torch
from modelscope import AutoModelForCausalLM, snapshot_download
from sentencepiece import SentencePieceProcessor

# Download the PyTorch-version checkpoint (about 600 GB) from ModelScope
model_dir = snapshot_download('colossalai/grok-1-pytorch', revision='v1.0.0')

# Tokenizer downloaded from the official grok-1 repository
sp = SentencePieceProcessor(model_file="tokenizer.model")

# Load the model in BF16; device_map="auto" spreads it across all visible GPUs
torch.set_default_dtype(torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()

# top_p and temperature only take effect when sampling is enabled
generation_kwargs = {
    "max_new_tokens": 100,
    "do_sample": True,
    "top_p": 0.95,
    "temperature": 0.3,
}

# Prompt: a line from a classical Chinese poem ("The bright moon shines among the pines,")
text = '明月松间照,\n\n->\n\n'
input_ids = torch.tensor([sp.encode(text)]).cuda()
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_kwargs,
    )

# Decode only the newly generated tokens, skipping the prompt
output_text = sp.decode(output[0][input_ids.shape[1]:].tolist())
print(output_text)
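The `temperature` and `top_p` settings above control how each next token is sampled. The function below is a simplified sketch of temperature scaling followed by nucleus (top-p) sampling on a batch of logits; it illustrates the technique and is not the transformers implementation, which differs in details. The example logits are made up.

```python
from typing import Optional

import torch


def sample_top_p(
    logits: torch.Tensor,
    temperature: float = 0.3,
    top_p: float = 0.95,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sample one token id per row from temperature-scaled, nucleus-filtered logits."""
    # Temperature < 1 sharpens the distribution toward the most likely tokens
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # Probability mass of all tokens ranked strictly above each token
    mass_above = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    # Keep the smallest prefix whose mass reaches top_p (the top token is always kept)
    sorted_probs = sorted_probs.masked_fill(mass_above >= top_p, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1, generator=generator)
    return sorted_idx.gather(-1, choice).squeeze(-1)


logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
print(sample_top_p(logits))  # tensor([0]): only token 0 lies inside the 0.95 nucleus
```

With `temperature=0.3`, token 0 alone holds about 96% of the probability mass, so the nucleus contains just that token; a higher temperature flattens the distribution and lets more tokens into the nucleus.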
Depending on your Internet speed, downloading the checkpoints (about 600 GB) can take from several hours to tens of hours, and loading them takes 5-10 minutes before inference starts. Don't worry, the process is not stuck.
Note: Running the example code requires a multi-GPU machine; for now, an 8x80 GB multi-GPU machine is required.
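A back-of-the-envelope estimate shows where the "about 600 GB" and 8x80 GB figures come from. This sketch assumes Grok-1's publicly announced size of roughly 314B parameters, 2 bytes per BF16 value, and an even split of the weights over 8 GPUs; it ignores activations and the KV cache.

```python
NUM_PARAMS = 314e9    # Grok-1's announced parameter count (approximate)
BYTES_PER_PARAM = 2   # BF16 stores each value in 2 bytes
NUM_GPUS = 8          # assumes the weights are split evenly across 8 GPUs


def weights_size_gb(num_params: float, bytes_per_param: int) -> float:
    """Size of the raw weights in decimal gigabytes."""
    return num_params * bytes_per_param / 1e9


total_gb = weights_size_gb(NUM_PARAMS, BYTES_PER_PARAM)
per_gpu_gb = total_gb / NUM_GPUS
print(f"total weights: {total_gb:.0f} GB, per GPU: {per_gpu_gb:.1f} GB")
# prints "total weights: 628 GB, per GPU: 78.5 GB"
```

At roughly 78.5 GB of weights per device, 80 GB GPUs are close to the minimum for 8-way sharding, which is why smaller cards or fewer GPUs do not fit the model.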
Performance
The following table compares the JAX version, the PyTorch version run with HuggingFace Auto Device, and the PyTorch version run with ColossalAI Tensor Parallelism (TP). To try it on ColossalAI, refer to the example of using the PyTorch-version Grok-1 in the ColossalAI GitHub Repository.
Results for a batch size of 1 and a maximum length of 100:
| Method | Initialization Duration (s) | Average Generation Latency (s) |
|---|---|---|
| ColossalAI | 431.45 | 14.92 |
| HuggingFace Auto-Device | 426.96 | 48.38 |
| JAX | 147.61 | 56.25 |
Tested on 8x NVIDIA H800 (80 GB) GPUs.
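The relative speedups follow directly from the average generation latencies in the table. The short calculation below uses only those reported numbers and computes how many times slower each method is than ColossalAI.

```python
# Average generation latency in seconds, copied from the benchmark table
latency_s = {
    "ColossalAI": 14.92,
    "HuggingFace Auto-Device": 48.38,
    "JAX": 56.25,
}


def slowdown(method: str, baseline: str = "ColossalAI") -> float:
    """How many times longer `method` takes than `baseline`."""
    return latency_s[method] / latency_s[baseline]


for method in latency_s:
    print(f"{method}: {slowdown(method):.2f}x")
# ColossalAI: 1.00x, HuggingFace Auto-Device: 3.24x, JAX: 3.77x
```

Initialization time is similar for the two PyTorch variants, so the gain from tensor parallelism shows up in generation latency rather than in loading.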