grok-1-pytorch

Open-source address: https://modelscope.cn/models/colossalai/grok-1-pytorch
License: apache-2.0


Grok-1 (PyTorch Version)

This repository contains the model and weights of the PyTorch version of the Grok-1 open-weights model. You can find complete example code for using the torch-version Grok-1 in the ColossalAI GitHub repository. We also apply parallelism techniques from the ColossalAI framework (Tensor Parallelism for now) to accelerate inference.

You can find the original weights released by xAI on Hugging Face and the original model in the Grok open release GitHub repository.
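
For intuition about the tensor parallelism mentioned above: large weight matrices are split across GPUs so that each device stores and computes only a shard, and the partial results are combined afterwards. The snippet below is a minimal conceptual sketch in plain PyTorch, not the ColossalAI API; the layer sizes and the two-GPU setup are illustrative assumptions.

import torch
import torch.nn as nn

# Conceptual sketch of (column-wise) tensor parallelism, not the ColossalAI API:
# a single Linear layer's weight is split across two GPUs, each GPU computes its
# shard, and the partial outputs are concatenated. ColossalAI automates this
# sharding and the required communication for the real Grok-1 layers.
hidden, ffn = 1024, 4096                       # illustrative sizes, not Grok-1's
full = nn.Linear(hidden, ffn, bias=False)

weight_shards = torch.chunk(full.weight.data, 2, dim=0)   # split the output dim
devices = ["cuda:0", "cuda:1"]
shard_layers = []
for shard, device in zip(weight_shards, devices):
    layer = nn.Linear(hidden, ffn // 2, bias=False).to(device)
    layer.weight.data.copy_(shard)
    shard_layers.append(layer)

x = torch.randn(1, hidden)
partials = [layer(x.to(device)) for layer, device in zip(shard_layers, devices)]
y = torch.cat([p.to("cuda:0") for p in partials], dim=-1)  # matches full(x)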

Conversion

We translated the original model, written in JAX, into a PyTorch version and converted the weights by mapping tensor files to parameter keys, de-quantizing the tensors with the corresponding packed scales, and saving them to checkpoint files with PyTorch APIs. BF16 weights are used for this PyTorch-version model.
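
As a rough illustration of the de-quantization step: each quantized weight tensor is expanded by its packed scales and cast to BF16 before being written into the PyTorch checkpoint. The sketch below is schematic only; the key names, block layout, and helper functions are assumptions, not the actual conversion script.

import torch

def dequantize(weight_q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Multiply each quantized block by its packed scale and cast to BF16.
    # (The block layout here is illustrative, not the real Grok-1 layout.)
    blocks = weight_q.reshape(scales.numel(), -1).to(torch.bfloat16)
    return (blocks * scales.reshape(-1, 1).to(torch.bfloat16)).reshape(weight_q.shape)

def convert_state_dict(jax_params: dict, key_map: dict) -> dict:
    # key_map: JAX parameter key -> PyTorch state_dict key (hypothetical mapping).
    state_dict = {}
    for jax_key, torch_key in key_map.items():
        entry = jax_params[jax_key]
        state_dict[torch_key] = dequantize(entry["weight"], entry["scales"])
    return state_dict

# torch.save(convert_state_dict(jax_params, key_map), "pytorch_model.bin")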

Tokenizer

Grok-1 has a vocabulary size of 131072. For now, the original tokenizer (i.e. tokenizer.model in the GitHub repository) should be used with the PyTorch-version model.

You should download the tokenizer from the official grok-1 repository.

wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model
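
After downloading, you can optionally verify that the tokenizer loads and reports the expected vocabulary size:

from sentencepiece import SentencePieceProcessor

# Optional sanity check: the tokenizer should report Grok-1's 131072-token vocabulary.
sp = SentencePieceProcessor(model_file="tokenizer.model")
print(sp.vocab_size())            # expected: 131072
print(sp.encode("Hello, Grok!"))  # a list of token ids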

Usage

import torch

from modelscope import AutoModelForCausalLM, snapshot_download
from sentencepiece import SentencePieceProcessor

# download the model weights from ModelScope (about 600 GB on first download)
model_dir = snapshot_download('colossalai/grok-1-pytorch', revision='v1.0.0')

# tokenizer downloaded from the official grok-1 repository
sp = SentencePieceProcessor(model_file="tokenizer.model")

torch.set_default_dtype(torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, 
    trust_remote_code=True, 
    device_map="auto", 
    torch_dtype=torch.bfloat16
).eval()
generation_kwargs = {
    "max_new_tokens": 100, 
    "top_p": 0.95, 
    "temperature": 0.3
}

text = '明月松间照,\n\n->\n\n'
input_ids = sp.encode(text)
input_ids = torch.tensor([input_ids]).cuda()
attention_mask = torch.ones_like(input_ids)
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    **generation_kwargs,
}

output = model.generate(**inputs)
# decode and strip the prompt prefix from the generated sequence
output_text = sp.decode(output[0].tolist())[len(text):]

print(output_text)

Depending on your Internet speed, it will take several hours to tens of hours to download the checkpoints (about 600 GB!), and 5-10 minutes to load them before inference starts. Don't worry, it's not stuck.

Note: a multi-GPU machine is required to run the example code (for now, an 8x80GB multi-GPU machine is required).

Performance

The following benchmark compares the JAX version, the PyTorch version run with HuggingFace Auto Device, and the PyTorch version run with ColossalAI (Tensor Parallelism). To try it on ColossalAI, please refer to the example of using the torch-version Grok-1 in the ColossalAI GitHub repository.

For a request with batch size 1 and maximum length 100:

Method                     Initialization Duration (sec)    Average Generation Latency (sec)
ColossalAI                 431.45                           14.92
HuggingFace Auto-Device    426.96                           48.38
JAX                        147.61                           56.25

Tested on 8x 80GB NVIDIA H800 GPUs.
