Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ncnn deployment examples #145

Merged
merged 14 commits into from
Jul 30, 2021
Prev Previous commit
Next Next commit
Adding onnx export tools
  • Loading branch information
zhiqwang committed Jul 26, 2021
commit 77f295a9695f8b124a177fd221cc89e55d56f106
6 changes: 3 additions & 3 deletions deployment/ncnn/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// This file is wirtten base on the following file:
// https://github.com/Tencent/ncnn/blob/master/examples/yolov5.cpp
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
@@ -10,7 +11,6 @@
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// ------------------------------------------------------------------------------

#include "layer.h"
#include "net.h"
Empty file added tools/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tools/export_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import argparse
import torch
from .yolort_deploy_friendly import yolov5_deploy_friendly


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov5s.pt',
help='weights path')
parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640],
help='image (height, width)')
parser.add_argument('--num_classes', type=int, default=80,
help='number of classes')
parser.add_argument('--batch_size', type=int, default=1,
help='batch size')
parser.add_argument('--device', default='cpu',
help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', action='store_true',
help='FP16 half-precision export')
parser.add_argument('--dynamic', action='store_true',
help='ONNX: dynamic axes')
parser.add_argument('--simplify', action='store_true',
help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=11,
help='ONNX: opset version')
return parser


def cli_main():
parser = get_parser()
args = parser.parse_args()
print(args)
export_onnx(args)


def export_onnx(args):

model = yolov5_deploy_friendly(
pretrained=True,
num_classes=args.num_classes,
)
inputs = torch.rand(args.batch_size, 3, 320, 320)
outputs = model(inputs)
print(outputs.shape)


if __name__ == "__main__":
cli_main()
105 changes: 105 additions & 0 deletions tools/yolort_deploy_friendly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import torch
from torch import nn, Tensor

from torchvision.models.utils import load_state_dict_from_url

from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.box_head import YOLOHead

from typing import Any, List, Optional


def yolov5_deploy_friendly(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
**kwargs: Any,
):
r"""yolov5 small release 4.0 model from
`"ultralytics/yolov5" <https://zenodo.org/badge/latestdoi/264818686>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
backbone_name = 'darknet_s_r4_0'
depth_multiple = 0.33
width_multiple = 0.5
version = 'r4.0'
backbone = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple, version=version)

model = YOLODeployFriendly(backbone, num_classes, **kwargs)

if pretrained:
model_urls_root = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0'
model_url = f'{model_urls_root}/yolov5_darknet_pan_s_r40_coco-e3fd213d.pt'
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)

return model


class YOLODeployFriendly(nn.Module):
"""
Deployment Friendly Wrapper of YOLO.
"""
def __init__(
self,
backbone: nn.Module,
num_classes: int,
# Anchor parameters
anchor_grids: Optional[List[List[float]]] = None,
anchor_generator: Optional[nn.Module] = None,
head: Optional[nn.Module] = None,
):
super().__init__()
if not hasattr(backbone, "out_channels"):
raise ValueError(
"backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the "
"same for all the levels)")
self.backbone = backbone

strides = [8, 16, 32]

if anchor_grids is None:
anchor_grids = [
[10, 13, 16, 30, 33, 23],
[30, 61, 62, 45, 59, 119],
[116, 90, 156, 198, 373, 326],
]

if anchor_generator is None:
anchor_generator = AnchorGenerator(strides, anchor_grids)
self.anchor_generator = anchor_generator

if head is None:
head = YOLOHead(
backbone.out_channels,
anchor_generator.num_anchors,
anchor_generator.strides,
num_classes,
)
self.head = head

def forward(self, samples: Tensor):
"""
Arguments:
samples (Tensor): batched images, of shape [batch_size x 3 x H x W]
"""
# get the features from the backbone
features = self.backbone(samples)

# compute the yolo heads outputs using the features
head_outputs = self.head(features)

all_pred_logits = []
batch_size, _, _, _, K = head_outputs[0].shape

for pred_logits in head_outputs:
pred_logits = pred_logits.reshape(batch_size, -1, K) # Size=(NN, HWA, K)
all_pred_logits.append(pred_logits)

all_pred_logits = torch.cat(all_pred_logits, dim=1)
return all_pred_logits