Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move load_from_ultralytics into _checkpoint.py #373

Merged
merged 7 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.io import read_image
from yolort import models
from yolort.models import YOLOv5
from yolort.models._utils import load_from_ultralytics
from yolort.models._checkpoint import load_from_ultralytics
from yolort.utils import get_image_from_url, read_image_to_tensor
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.v5 import letterbox, scale_coords, attempt_download
Expand Down
216 changes: 216 additions & 0 deletions yolort/models/_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) 2020, yolort team. All rights reserved.

from functools import reduce
from typing import Dict, List, Optional

from torch import nn
from yolort.v5 import get_yolov5_size, load_yolov5_model

from .backbone_utils import darknet_pan_backbone
from .box_head import YOLOHead

__all__ = ["load_from_ultralytics"]


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
    """
    Allows the user to load model state file from the checkpoint trained from
    the ultralytics/yolov5.

    Args:
        checkpoint_path (str): Path of the YOLOv5 checkpoint model.
        version (str): upstream version released by the ultralytics/yolov5, Possible
            values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".

    Returns:
        dict: architecture hyper-parameters of the checkpoint plus its converted
        ``state_dict`` (half precision), keyed for yolort model construction.
    """

    if version not in ["r3.1", "r4.0", "r6.0"]:
        raise NotImplementedError(
            f"Currently does not support version: {version}. Feel free to file an issue "
            "labeled enhancement to us."
        )

    model = load_yolov5_model(checkpoint_path)
    num_classes = model.yaml["nc"]
    strides = model.stride
    detect_layer = model.model[-1]
    # YOLOv5 may rewrite the anchors with its auto-anchor mechanism, so recompute
    # the anchor grids from the live detection layer instead of trusting
    # model.yaml["anchors"].
    num_anchors = detect_layer.anchors.shape[1]
    scaled_anchors = detect_layer.anchors * detect_layer.stride.view(-1, 1, 1)
    anchor_grids = scaled_anchors.reshape(1, -1, 2 * num_anchors).tolist()[0]

    depth_multiple = model.yaml["depth_multiple"]
    width_multiple = model.yaml["width_multiple"]

    # Four detection strides means the checkpoint carries an extra P6 output.
    use_p6 = len(strides) == 4

    if use_p6:
        inner_block_maps = {"0": "11", "1": "12", "3": "15", "4": "16", "6": "19", "7": "20"}
        layer_block_maps = {"0": "23", "1": "24", "2": "26", "3": "27", "4": "29", "5": "30", "6": "32"}
        p6_block_maps = {"0": "9", "1": "10"}
        head_ind = 33
    else:
        inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
        layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
        p6_block_maps = None
        head_ind = 24

    converter = CheckpointConverter(
        depth_multiple,
        width_multiple,
        inner_block_maps=inner_block_maps,
        layer_block_maps=layer_block_maps,
        p6_block_maps=p6_block_maps,
        strides=strides,
        anchor_grids=anchor_grids,
        head_ind=head_ind,
        head_name="m",
        num_classes=num_classes,
        version=version,
        use_p6=use_p6,
    )
    converter.updating(model)
    state_dict = converter.model.half().state_dict()

    return {
        "num_classes": num_classes,
        "depth_multiple": depth_multiple,
        "width_multiple": width_multiple,
        "strides": strides,
        "anchor_grids": anchor_grids,
        "use_p6": use_p6,
        "size": get_yolov5_size(depth_multiple, width_multiple),
        "state_dict": state_dict,
    }


class ModelWrapper(nn.Module):
    """
    Thin container pairing a backbone with a detection head.

    Only the backbone and the head hold parameters, so wrapping just these two
    modules is enough to expose the full trainable state for ``state_dict()``.
    """

    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head


class CheckpointConverter:
    """
    Update checkpoint from ultralytics yolov5.

    Builds a fresh yolort model (backbone + PAN + head) and copies the weights
    and buffers of an upstream ultralytics YOLOv5 model into it, remapping the
    flat integer layer indices of the upstream ``nn.Sequential`` onto yolort's
    named submodules via the ``*_maps`` dictionaries.

    Args:
        depth_multiple (float): depth multiple of the checkpoint architecture.
        width_multiple (float): width multiple of the checkpoint architecture.
        inner_block_maps (Dict[str, str], optional): mapping from yolort PAN
            inner-block child index to upstream sequential index.
        layer_block_maps (Dict[str, str], optional): mapping from yolort PAN
            layer-block child index to upstream sequential index.
        p6_block_maps (Dict[str, str], optional): mapping for the extra P6
            blocks; ``None`` when the checkpoint has no P6 output.
        strides (List[int], optional): strides of the detection layers.
        anchor_grids (List[List[float]], optional): anchor grids (already
            scaled by stride).
        head_ind (int): index of the detection head module in the upstream
            sequential. Default: 24.
        head_name (str): attribute name of the head's module list on the
            upstream detection layer. Default: "m".
        num_classes (int): number of detection classes. Default: 80.
        version (str): upstream release tag, e.g. "r6.0". Default: "r6.0".
        use_p6 (bool): whether the architecture has a P6 output. Default: False.
    """

    def __init__(
        self,
        depth_multiple: float,
        width_multiple: float,
        inner_block_maps: Optional[Dict[str, str]] = None,
        layer_block_maps: Optional[Dict[str, str]] = None,
        p6_block_maps: Optional[Dict[str, str]] = None,
        strides: Optional[List[int]] = None,
        anchor_grids: Optional[List[List[float]]] = None,
        head_ind: int = 24,
        head_name: str = "m",
        num_classes: int = 80,
        version: str = "r6.0",
        use_p6: bool = False,
    ) -> None:

        # Configuration for making the keys consistent
        if inner_block_maps is None:
            inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
        self.inner_block_maps = inner_block_maps
        if layer_block_maps is None:
            layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
        self.layer_block_maps = layer_block_maps
        self.p6_block_maps = p6_block_maps
        self.head_ind = head_ind
        self.head_name = head_name

        # Set model
        yolov5_size = get_yolov5_size(depth_multiple, width_multiple)
        backbone_name = f"darknet_{yolov5_size}_{version.replace('.', '_')}"

        backbone = darknet_pan_backbone(
            backbone_name, depth_multiple, width_multiple, version=version, use_p6=use_p6
        )
        # Each anchor contributes a (w, h) pair, hence the division by 2.
        num_anchors = len(anchor_grids[0]) // 2
        head = YOLOHead(backbone.out_channels, num_anchors, strides, num_classes)
        # Only backbone and head contain parameters inside, so we only wrap them both here.
        self.model = ModelWrapper(backbone, head)

    def updating(self, state_dict):
        """
        Copy parameters and buffers from an upstream YOLOv5 model into
        ``self.model`` in place.

        Args:
            state_dict: despite the name, this is the loaded upstream model
                (an ``nn.Module``); ``obtain_module_sequential`` unwraps it to
                the flat ``nn.Sequential`` of layers.
        """
        # Obtain module state
        state_dict = obtain_module_sequential(state_dict)

        # Update backbone weights
        for name, params in self.model.backbone.body.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, None))

        for name, buffers in self.model.backbone.body.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, None))

        # Update PAN weights
        # Updating P6 weights
        if self.p6_block_maps is not None:
            for name, params in self.model.backbone.pan.intermediate_blocks.p6.named_parameters():
                params.data.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

            for name, buffers in self.model.backbone.pan.intermediate_blocks.p6.named_buffers():
                buffers.copy_(self.attach_parameters_block(state_dict, name, self.p6_block_maps))

        # Updating inner_block weights
        for name, params in self.model.backbone.pan.inner_blocks.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

        for name, buffers in self.model.backbone.pan.inner_blocks.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, self.inner_block_maps))

        # Updating layer_block weights
        for name, params in self.model.backbone.pan.layer_blocks.named_parameters():
            params.data.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

        for name, buffers in self.model.backbone.pan.layer_blocks.named_buffers():
            buffers.copy_(self.attach_parameters_block(state_dict, name, self.layer_block_maps))

        # Update YOLOHead weights
        for name, params in self.model.head.named_parameters():
            params.data.copy_(self.attach_parameters_heads(state_dict, name))

        for name, buffers in self.model.head.named_buffers():
            buffers.copy_(self.attach_parameters_heads(state_dict, name))

    @staticmethod
    def attach_parameters_block(state_dict, name, block_maps=None):
        """
        Fetch the upstream tensor matching a dotted yolort parameter name.

        The first key segment is the child index within the yolort block; it is
        translated through ``block_maps`` (when given) into the upstream
        sequential index, and the remaining segments are resolved as nested
        attributes on that upstream layer.
        """
        keys = name.split(".")
        ind = int(block_maps[keys[0]]) if block_maps else int(keys[0])
        return rgetattr(state_dict[ind], keys[1:])

    def attach_parameters_heads(self, state_dict, name):
        """
        Fetch the upstream head tensor matching a dotted yolort head name.

        Head parameter names look like ``head.<i>....``, so ``keys[1]`` is the
        per-scale conv index inside the upstream detection layer's module list
        (attribute ``self.head_name`` on layer ``self.head_ind``).
        """
        keys = name.split(".")
        ind = int(keys[1])
        return rgetattr(getattr(state_dict[self.head_ind], self.head_name)[ind], keys[2:])


def rgetattr(obj, attr, *args):
    """
    Nested version of getattr.

    Ref: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects

    Args:
        obj: root object to start the lookup from.
        attr: attribute path, either a dot-delimited string (``"a.b.c"``, as in
            the referenced recipe) or an iterable of attribute names
            (``["a", "b", "c"]``, as used by ``CheckpointConverter``). An empty
            path returns ``obj`` itself.
        *args: optional single default value returned when an attribute in the
            chain is missing (same contract as the builtin ``getattr``).
    """
    # Accept the canonical dotted-string form as well as the pre-split list
    # used internally; previously a string argument was iterated char-by-char.
    if isinstance(attr, str):
        attr = attr.split(".")

    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return reduce(_getattr, [obj] + list(attr))


def obtain_module_sequential(state_dict):
    """
    Unwrap nested ``.model`` attributes until an ``nn.Sequential`` is reached.

    Upstream YOLOv5 checkpoints wrap the flat layer sequence in one or more
    container modules exposing a ``model`` attribute; this walks down through
    them and returns the underlying ``nn.Sequential``.
    """
    module = state_dict
    while not isinstance(module, nn.Sequential):
        module = module.model
    return module
Loading