tresnet_m.miil_in21k

Anonymous user, July 31, 2024
Tech stack: pytorch
Categories: ai, timm, image-classification
Source: https://modelscope.cn/models/timm/tresnet_m.miil_in21k
License: apache-2.0

Project details

Model card for tresnet_m.miil_in21k

A TResNet image classification model, trained on ImageNet-21K-P ("ImageNet-21K Pretraining for the Masses", an 11k-class subset of ImageNet-22k) by the paper authors.

The weights for this model have been remapped and modified from the originals to work with standard BatchNorm instead of InplaceABN. The inplace_abn package has recently become problematic to build, and it ends up slower when combined with memory_format=channels_last, torch.compile(), etc.
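InplaceABN (In-Place Activated BatchNorm) fuses a BatchNorm layer with the activation that follows it in order to save training memory. At inference time BatchNorm is just a per-channel affine transform, so the same computation can be written as a standard BatchNorm followed by a separate activation. The plain-Python sketch below checks that equivalence on one channel. The statistics, input values, and the choice of leaky ReLU with slope 0.01 are illustrative assumptions, and the sketch ignores library-specific details of how inplace_abn stores its parameters.

```python
import math


def batchnorm_eval(x, mean, var, gamma, beta, eps=1e-5):
    """Inference-mode BatchNorm for one channel: normalize, then scale and shift."""
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]


def leaky_relu(x, slope=0.01):
    """Leaky ReLU; the default slope is an illustrative assumption."""
    return [v if v >= 0 else slope * v for v in x]


def fused_bn_act(x, mean, var, gamma, beta, eps=1e-5, slope=0.01):
    """Toy stand-in for a fused 'activated BatchNorm': normalize and activate in one pass."""
    out = []
    for v in x:
        y = gamma * (v - mean) / math.sqrt(var + eps) + beta
        out.append(y if y >= 0 else slope * y)
    return out


# Hypothetical activations and per-channel statistics for a single channel.
x = [-2.0, -0.5, 0.0, 1.5, 3.0]
params = dict(mean=0.5, var=4.0, gamma=1.2, beta=-0.1)

separate = leaky_relu(batchnorm_eval(x, **params))  # standard BatchNorm, then activation
fused = fused_bn_act(x, **params)                    # fused layer
print(separate == fused)  # True
```

Both paths perform identical floating-point operations in the same order, which is why the results match exactly rather than only approximately.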

Model Details

  • Model Type: Image classification / feature backbone
  • Model Stats:
      • Params (M): 52.3
      • GMACs: 5.8
      • Activations (M): 7.3
      • Image size: 224 x 224
  • Papers:
      • TResNet: High Performance GPU-Dedicated Architecture: https://arxiv.org/abs/2003.13630
      • ImageNet-21K Pretraining for the Masses: https://arxiv.org/abs/2104.10972
  • Pretrain Dataset: ImageNet-21K-P
  • Original:
      • https://github.com/Alibaba-MIIL/TResNet
      • https://github.com/Alibaba-MIIL/ImageNet21K
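The Model Stats figures can be turned into rough resource estimates. The sketch below assumes float32 weights (4 bytes per parameter, with a float16 variant for comparison) and the common convention that one multiply-accumulate counts as two floating-point operations. Neither assumption comes from the model card, and real inference memory also includes activations and framework overhead.

```python
# Figures copied from the Model Stats list.
params_m = 52.3  # parameters, in millions
gmacs = 5.8      # multiply-accumulates per 224 x 224 image, in billions

# Assumption: float32 weights (4 bytes) vs. float16 weights (2 bytes).
# Millions of parameters x bytes per parameter = megabytes (10^6 bytes).
weight_mb_fp32 = params_m * 4
weight_mb_fp16 = params_m * 2

# Assumption: 1 MAC = 2 FLOPs (one multiply plus one add).
gflops = 2 * gmacs

print(f"fp32 weights: ~{weight_mb_fp32:.1f} MB")  # ~209.2 MB
print(f"fp16 weights: ~{weight_mb_fp16:.1f} MB")  # ~104.6 MB
print(f"compute: ~{gflops:.1f} GFLOPs per image")  # ~11.6
```

Compute scales linearly with batch size, so a batch of 32 images at 224 x 224 would need roughly 32 times the per-image figure.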

Model Usage

Image Classification

from urllib.request import urlopen
from PIL import Image
import timm
import torch  # needed for torch.topk below

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('tresnet_m.miil_in21k', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
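The last line of the example converts logits to percentages and keeps the five largest. The stdlib-only sketch below mirrors output.softmax(dim=1) * 100 followed by torch.topk(..., k=5) for a single image. The logits are made-up values over 8 classes, and the softmax and topk helpers are hypothetical stand-ins for the torch operations, not part of timm.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a 1-D list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk(values, k):
    """Return (values, indices) of the k largest entries, largest first, like torch.topk on 1-D input."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    return [values[i] for i in order], order


# Hypothetical logits for one image over 8 classes (the real classifier has far more).
logits = [1.0, 3.5, -0.5, 2.0, 0.0, 3.0, -1.0, 0.5]
probs_pct = [p * 100 for p in softmax(logits)]  # mirrors output.softmax(dim=1) * 100
top5_probabilities, top5_class_indices = topk(probs_pct, k=5)
print(top5_class_indices)  # [1, 5, 3, 0, 7]
```

Subtracting the maximum logit before exponentiating does not change the result but avoids overflow for large logits.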

Feature Map Extraction

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'tresnet_m.miil_in21k',
    pretrained=True,
    features_only=True,
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

for o in output:
    # print shape of each feature map in output
    # e.g.:
    #  torch.Size([1, 64, 56, 56])
    #  torch.Size([1, 128, 28, 28])
    #  torch.Size([1, 1024, 14, 14])
    #  torch.Size([1, 2048, 7, 7])

    print(o.shape)
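Each feature map's spatial size is the input size divided by that stage's reduction (stride), which for the shapes listed above works out to 4, 8, 16 and 32 at a 224 x 224 input. The sketch below, using a hypothetical feature_shapes helper that is not part of timm, reproduces those shapes and predicts them for other input sizes. In timm, the same channel and reduction lists can be read from model.feature_info.channels() and model.feature_info.reduction() on the features_only model. Restricting inputs to multiples of 32 is a simplification of this sketch, not a documented requirement of the model.

```python
# Channels and reductions implied by the example shapes (224 / 56 = 4, 224 / 28 = 8, ...).
channels = [64, 128, 1024, 2048]
reductions = [4, 8, 16, 32]


def feature_shapes(batch, height, width):
    """Expected (N, C, H, W) of each feature map; this sketch only accepts multiples of 32."""
    if height % reductions[-1] or width % reductions[-1]:
        raise ValueError("height and width must be multiples of the largest reduction")
    return [(batch, c, height // r, width // r) for c, r in zip(channels, reductions)]


print(feature_shapes(1, 224, 224))
# [(1, 64, 56, 56), (1, 128, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]
print(feature_shapes(2, 448, 448)[-1])  # (2, 2048, 14, 14)
```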

Image Embeddings

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'tresnet_m.miil_in21k',
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # output is a (batch_size, num_features) shaped tensor

# or equivalently (without needing to set num_classes=0)

output = model.forward_features(transforms(img).unsqueeze(0))
# output is unpooled, a (1, 2048, 7, 7) shaped tensor

output = model.forward_head(output, pre_logits=True)
# output is a (1, num_features) shaped tensor
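forward_head(..., pre_logits=True) collapses the unpooled (1, 2048, 7, 7) map into one vector per image. Assuming the head uses global average pooling (timm's usual default for this architecture, as far as we know), each embedding entry is the mean of one channel over its 7 x 7 grid, so num_features would be 2048. The toy sketch below shows that reduction on nested lists with made-up values; global_avg_pool is a hypothetical helper, not a timm function.

```python
def global_avg_pool(features):
    """Average each channel over its H x W grid: (C, H, W) nested lists -> list of C floats."""
    return [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in features
    ]


# Hypothetical unpooled features for one image: 3 channels on a 2 x 2 grid
# (the real model produces 2048 channels on a 7 x 7 grid at 224 x 224).
features = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [0.0, 8.0]],
    [[-1.0, 1.0], [-1.0, 1.0]],
]
embedding = global_avg_pool(features)
print(embedding)  # [2.5, 2.0, 0.0]
```

Because pooling averages away the spatial grid, the embedding length depends only on the channel count, not on the input resolution.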

Citation

@misc{ridnik2020tresnet,
    title={TResNet: High Performance GPU-Dedicated Architecture},
    author={Tal Ridnik and Hussam Lawen and Asaf Noy and Itamar Friedman},
    year={2020},
    eprint={2003.13630},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{ridnik2021imagenet21k,
    title={ImageNet-21K Pretraining for the Masses},
    author={Tal Ridnik and Emanuel Ben-Baruch and Asaf Noy and Lihi Zelnik-Manor},
    year={2021},
    eprint={2104.10972},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}