Zero-Shot Text Classification - SSTuning-base - English

Anonymous user, July 31, 2024
Categories: ai, roberta, pytorch
Source: https://modelscope.cn/models/iic/zero-shot-classify-SSTuning-base
License: MIT

Project details

Model introduction

Quick start

Pipeline

from modelscope.pipelines import pipeline

model = 'damo/zero-shot-classify-SSTuning-base'
pipe = pipeline('zero-shot-classify-sstuning', model=model, model_revision='v1.0.2')

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]

output = pipe(text, list_label=list_label)
print(output)
# {'prediction': 'positive.', 'probability': '99.92%'}
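The pipeline output shown above holds both values as strings: the label keeps its trailing period and the probability is a percentage. For downstream use it can help to normalize them. The following is a minimal sketch, assuming the output always has the shape shown in the comment above; `parse_pipeline_output` is a hypothetical helper and not part of ModelScope.

```python
def parse_pipeline_output(output):
    """Turn a result like {'prediction': 'positive.', 'probability': '99.92%'}
    into a (label, probability) tuple such as ('positive', 0.9992).

    Assumption: the dict shape matches the example output shown above.
    """
    # Drop the trailing period that is appended to each label.
    label = output['prediction'].rstrip('.')
    # Convert a percentage string such as '99.92%' to a float in [0, 1].
    probability = float(output['probability'].rstrip('%')) / 100.0
    return label, probability


example = {'prediction': 'positive.', 'probability': '99.92%'}
label, probability = parse_pipeline_output(example)
print(label, round(probability, 4))
```

The sketch works on a hard-coded example dict, so it runs without downloading the model.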

Python

from modelscope import AutoTokenizer, AutoModelForSequenceClassification, snapshot_download
import torch, string, random

model_dir = snapshot_download("damo/zero-shot-classify-SSTuning-base", revision='v1.0.2')

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]  # The number of labels should be between 2 and 20.

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = list(string.ascii_uppercase)

def check_text(model, text, list_label, shuffle=False):
    # Make sure every label ends with a period, then pad the option list to 20 entries.
    list_label = [x if x.endswith('.') else x + '.' for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Build the input: "(A) label1. (B) label2. ... [sep_token] text"
    s_option = ' '.join(['(' + list_ABC[i] + ') ' + list_label_new[i] for i in range(len(list_label_new))])
    text = f'{s_option} {tokenizer.sep_token} {text}'

    model.to(device).eval()
    encoding = tokenizer([text], truncation=True, max_length=512, return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    with torch.no_grad():  # inference only, so gradients are not needed
        logits = model(**item).logits

    # Without shuffling, the real labels occupy the first len(list_label) positions,
    # so the logits of the padded options are dropped.
    logits = logits if shuffle else logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x, 5) for x in probs[0]]

    print(f'prediction:    {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability:   {round(probabilities[predictions]*100,2)}%')

check_text(model, text, list_label)
# prediction:    1 => (B) positive.
# probability:   99.92%
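To make the input format used by `check_text` easier to inspect, the option string can be built without loading the model. The following is a sketch that mirrors the string construction above. It assumes the RoBERTa special tokens `<pad>` and `</s>` as defaults (in the real code these come from `tokenizer.pad_token` and `tokenizer.sep_token`); `build_prompt` is a hypothetical helper name.

```python
import string


def build_prompt(text, labels, pad_token='<pad>', sep_token='</s>', max_options=20):
    """Build the '(A) ... (B) ... <sep> text' input string used above.

    Assumption: pad_token and sep_token default to the RoBERTa special tokens;
    pass tokenizer.pad_token and tokenizer.sep_token when a tokenizer is loaded.
    """
    if not 2 <= len(labels) <= max_options:
        raise ValueError(f'expected 2 to {max_options} labels, got {len(labels)}')
    # Ensure each label ends with a period, as check_text does.
    options = [x if x.endswith('.') else x + '.' for x in labels]
    # Pad the option list to a fixed length of max_options entries.
    options += [pad_token] * (max_options - len(options))
    letters = string.ascii_uppercase
    s_option = ' '.join(f'({letters[i]}) {opt}' for i, opt in enumerate(options))
    return f'{s_option} {sep_token} {text}'


prompt = build_prompt("The food is fresh.", ["negative", "positive"])
print(prompt)
```

With two labels, options (A) and (B) carry the labels and options (C) through (T) are filled with the padding token, which is why `check_text` keeps only the first `len(list_label)` logits when the options are not shuffled.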

Related paper and citation

@inproceedings{acl23/SSTuning,
  author    = {Chaoqun Liu and
               Wenxuan Zhang and
               Guizhen Chen and
               Xiaobao Wu and
               Anh Tuan Luu and
               Chip Hong Chang and 
               Lidong Bing},
  title     = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
  year      = {2023},
  url       = {https://arxiv.org/abs/2305.11442},
}