mamba-370m-hf

Anonymous user, 2024-07-31
Categories: ai, mamba, pytorch
Source: https://modelscope.cn/models/AI-ModelScope/mamba-370m-hf

Mamba

This repository contains the transformers-compatible mamba-370m checkpoint. The checkpoint weights are untouched, but the full config.json and tokenizer have been pushed to this repo.

Usage

Until transformers 4.39.0 is released, you need to install transformers from the main branch:

pip install git+https://github.com/huggingface/transformers@main
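
To confirm that the installed transformers is recent enough, a small check along the following lines may help. This is a minimal sketch: the helper names numeric_release, meets_minimum and transformers_is_recent_enough are hypothetical, and the comparison looks only at the leading numeric components, so a development build such as 4.39.0.dev0 installed from main is treated as meeting the 4.39.0 minimum.

```python
import re
from importlib import metadata


def numeric_release(version: str) -> tuple:
    """Return the leading numeric components, e.g. '4.39.0.dev0' -> (4, 39, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric piece such as 'dev0'
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "4.39.0") -> bool:
    """Compare numeric release tuples; pre-release suffixes are ignored."""
    return numeric_release(installed) >= numeric_release(minimum)


def transformers_is_recent_enough() -> bool:
    """Return False when transformers is missing or older than the minimum."""
    try:
        return meets_minimum(metadata.version("transformers"))
    except metadata.PackageNotFoundError:
        return False


if __name__ == "__main__":
    print("transformers OK:", transformers_is_recent_enough())
```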

We also recommend installing both causal-conv1d and mamba-ssm:

pip install "causal-conv1d>=1.2.0"
pip install mamba-ssm

If either of these is not installed, the "eager" implementation is used. Otherwise, the more optimised CUDA kernels are used.
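
To see which path to expect before loading the model, one option is to test whether both packages are importable. The sketch below is an approximation, not the check transformers itself performs: the helper names are hypothetical, and it only tests for the mamba_ssm and causal_conv1d modules without verifying that a CUDA device is actually present.

```python
from importlib.util import find_spec

# Import names of the two optional packages (pip names: mamba-ssm, causal-conv1d).
KERNEL_MODULES = ("mamba_ssm", "causal_conv1d")


def all_importable(module_names) -> bool:
    """True when every top-level module name can be located without importing it."""
    return all(find_spec(name) is not None for name in module_names)


def expected_mamba_path() -> str:
    """Rough guess at the implementation that will be picked.

    Assumption: having both packages installed is treated as sufficient;
    a missing GPU is not detected here.
    """
    return "cuda kernels" if all_importable(KERNEL_MODULES) else "eager"


if __name__ == "__main__":
    print("Expected implementation:", expected_mamba_path())
```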

Generation

You can use the classic generate API:

>>> from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
>>> input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

>>> out = model.generate(input_ids, max_new_tokens=10)
>>> print(tokenizer.batch_decode(out))
["Hey how are you doing?\n\nI'm doing great.\n\nI"]

PEFT finetuning example

To fine-tune with the peft library, we recommend keeping the model in float32:

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
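
For a sense of how small the trained adapter is with r=8, recall that LoRA adds two low-rank matrices per adapted linear layer, A of shape (r, in_features) and B of shape (out_features, r), so the extra parameter count is r * (in_features + out_features). The sketch below only does that arithmetic; the helper name and the layer shapes are hypothetical placeholders for illustration, not values read from this checkpoint.

```python
def lora_extra_params(in_features: int, out_features: int, r: int = 8) -> int:
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * (in_features + out_features)


# Hypothetical (in_features, out_features) shapes, chosen only for illustration.
example_shapes = {
    "in_proj": (1024, 4096),
    "out_proj": (2048, 1024),
}

if __name__ == "__main__":
    for name, (d_in, d_out) in example_shapes.items():
        full = d_in * d_out
        extra = lora_extra_params(d_in, d_out, r=8)
        print(f"{name}: full weight {full} params, LoRA adds {extra}")
```

To get real numbers, the shapes would need to be read from the loaded model's modules rather than assumed.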