MusiLingo-musicqa-v1

Categories: ai, musilingo, pytorch, art, music
Repository: https://modelscope.cn/models/m-a-p/MusiLingo-musicqa-v1
License: cc-by-nc-4.0


Model Card for MusiLingo-musicqa-v1

Model Details

Model Description

The model consists of a music encoder (MERT-v1-330M), a natural language decoder (vicuna-7b-delta-v0), and a linear projection layer between the two.

This MusiLingo checkpoint was trained on the MusicQA dataset and can answer instructions about raw music audio, such as queries about tempo, emotion, genre, tags, or subjective impressions. You can use the MusicQA dataset for the demo below. For the implementation of MusicQA, please refer to our GitHub repo.
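
For intuition, the projection layer described above maps MERT frame embeddings into the LLM's token-embedding space, so projected audio frames can be concatenated with text-token embeddings (as done in the answer() helper below). The sketch here is illustrative only; the class name and dimensions are assumptions, and the actual implementation is in the GitHub repo.

import torch.nn as nn

# Illustrative sketch, not the exact MusiLingo code: MERT frame embeddings
# (assumed dim 1024) are mapped into the LLM token-embedding space
# (assumed dim 4096 for Vicuna-7B) by a single linear layer.
class AudioToLLMProjection(nn.Module):
    def __init__(self, audio_dim=1024, llm_dim=4096):
        super().__init__()
        self.proj = nn.Linear(audio_dim, llm_dim)

    def forward(self, audio_embeds):
        # audio_embeds: (batch, num_frames, audio_dim) -> (batch, num_frames, llm_dim)
        return self.proj(audio_embeds)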

Getting Started

from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList



class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False

# Generate a free-text answer for a batch of audio + instruction samples,
# using the MusiLingo model passed in as `self`.
def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
           repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
    audio = samples["audio"].cuda()
    audio_embeds, atts_audio = self.encode_audio(audio)
    if 'instruction_input' in samples:  # instruction dataset
        #print('Instruction Batch')
        instruction_prompt = []
        for instruction in samples['instruction_input']:
            prompt = '<Audio><AudioHere></Audio> ' + instruction
            instruction_prompt.append(self.prompt_template.format(prompt))
        audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
    self.llama_tokenizer.padding_side = "right"
    batch_size = audio_embeds.shape[0]
    bos = torch.ones([batch_size, 1],
                    dtype=torch.long,
                    device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
    bos_embeds = self.llama_model.model.embed_tokens(bos)
    atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
    outputs = self.llama_model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model may emit an unknown token <unk> at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # if there is a start token <s> at the beginning, remove it
        output_token = output_token[1:]
    output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

# MERT feature extractor used as the audio front-end.
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
# MusicQADataset is provided in the MusiLingo GitHub repo; `path` should point to your local data root.
ds = MusicQADataset(processor, f'{path}/data/music_data', 'Eval')
dl = DataLoader(
                ds,
                batch_size=1,
                num_workers=0,
                pin_memory=True,
                shuffle=False,
                drop_last=True,
                collate_fn=ds.collater
                )

# Stop generation once the '###' separator (the token ids below in the Vicuna/LLaMA tokenizer) is produced.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                torch.tensor([2277, 29937]).cuda()])])

from transformers import AutoModel
# trust_remote_code=True is needed because MusiLingo ships custom model code;
# move the model to GPU, since answer() above assumes CUDA tensors.
musilingo_musicqa = AutoModel.from_pretrained("m-a-p/MusiLingo-musicqa-v1", trust_remote_code=True).cuda()

for idx, sample in tqdm(enumerate(dl)):
    ans = answer(musilingo_musicqa.model, sample, stopping, length_penalty=100, temperature=0.1)
    txt = sample['text_input'][0]
    print(txt)
    print(ans)
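
To query a single audio clip instead of iterating over MusicQA, a minimal sketch along the same lines might look like the following. The file name, question, and preprocessing are assumptions; the sample keys ("audio", "instruction_input") follow the answer() helper above.

import torchaudio

# Hypothetical single-clip example (file name and question are placeholders).
wav, sr = torchaudio.load("example.wav")
wav = torchaudio.functional.resample(wav.mean(dim=0), sr, processor.sampling_rate)
inputs = processor(wav, sampling_rate=processor.sampling_rate, return_tensors="pt")
sample = {
    "audio": inputs["input_values"],  # raw waveform tensor passed to encode_audio()
    "instruction_input": ["What is the mood of this piece?"],
}
print(answer(musilingo_musicqa.model, sample, stopping, temperature=0.1))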

Citing This Work

If you find the work useful for your research, please consider citing it using the following BibTeX entry:

@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}