图像旋转判断模型

我要开发同款
匿名用户2024年07月31日
20阅读
所属分类ai
开源地址https://modelscope.cn/models/Cherrytest/rot_bgr
授权协议Apache License 2.0

作品详情

模型描述:

在人脸相关应用中,用户上传图像往往具有整体的90度整数倍的旋转,导致人脸角度与正向角度具有较大偏差,如果不加处理直接送入后续模型,会严重影响模型的输出结果。图像旋转判断模型可以对含有人脸的图片整体在平面内的角度进行判断,若图片中人脸是正向的,则模型输出旋转0度,若不是则根据图片中的实际朝向输出90度,180度和270度中的相应值,并以此旋转图像。

使用方式和范围:

使用方式:

输入图片,输出旋转后的图片。

目标场景:

人脸相关的前置基础能力,可应用于人像美颜/互动娱乐/人脸比对等场景的数据预处理阶段,尤其适应于端上人脸应用场景。

代码范例:

import cv2
import os
import numpy as np
import torch
from modelscope import snapshot_download
from PIL import Image
import onnxruntime

def softmax(x):
    x -= np.max(x, axis=0, keepdims=True)
    x = np.exp(x) / np.sum(np.exp(x), axis=0, keepdims=True)
    return x

def get_rot(image, ort_session):
    img_cv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    img_clone = img_cv.copy()
    img_np = cv2.resize(img_cv, (224, 224))
    img_np = img_np.astype(np.float32)
    mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape((1, 1, 3))
    norm = np.array([0.01742919, 0.017507, 0.01712475], dtype=np.float32).reshape((1, 1, 3))
    img_np = (img_np - mean) * norm
    img_tensor = torch.from_numpy(img_np)
    img_tensor = img_tensor.unsqueeze(0)
    img_nchw = img_tensor.permute(0, 3, 1, 2)
    ort_inputs = {ort_session.get_inputs()[0].name: img_nchw.numpy()}
    outputs = ort_session.run(None, ort_inputs)
    logits = outputs[0].reshape((-1,))
    probs = softmax(logits)
    rot_idx = np.argmax(probs)
    if rot_idx == 1:
        print('rot 90')
        img_clone = cv2.transpose(img_clone)
        img_clone = np.flip(img_clone, 1)
        return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
    elif rot_idx == 2:
        print('rot 180')
        img_clone = cv2.flip(img_clone, -1)
        return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
    elif rot_idx == 3:
        print('rot 270')
        img_clone = cv2.transpose(img_clone)
        img_clone = np.flip(img_clone, 0)
        return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
    else:
        return image

model_dir = snapshot_download('Cherrytest/rot_bgr', revision='v1.0.0')
model_path = os.path.join(model_dir, 'rot_bgr.onnx')
ort_session = onnxruntime.InferenceSession(model_path)
img_path = 'path_of_your_image'
image = Image.open(img_path)
image = image.convert('RGB')
image = get_rot(image, ort_session)
out_path = 'path_to_save_image'
image.save(out_path)

来源说明: 本模型及代码来自达摩院自研技术。

声明:本文仅代表作者观点,不代表本站立场。如果侵犯到您的合法权益,请联系我们删除侵权资源!如果遇到资源链接失效,请您通过评论或工单的方式通知管理员。未经允许,不得转载,本站所有资源文章禁止商业使用运营!
下载安装【程序员客栈】APP
实时对接需求、及时收发消息、丰富的开放项目需求、随时随地查看项目状态

评论