Move Lightning wrapper into trainer #314

Merged
merged 11 commits on Feb 13, 2022
54 changes: 9 additions & 45 deletions test/test_engine.py → test/test_models_yolov5.py
@@ -1,11 +1,10 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
# Copyright (c) 2021, yolort team. All rights reserved.

import pytorch_lightning as pl
import pytest
import torch
from torch import Tensor
from torchvision.io import read_image
from yolort.data import COCOEvaluator, DetectionDataModule, _helper as data_helper
from yolort.data import COCOEvaluator, _helper as data_helper
from yolort.models import yolov5s
from yolort.models.transform import YOLOTransform
from yolort.models.yolo import yolov5_darknet_pan_s_r31
@@ -75,57 +74,22 @@ def test_train_with_vanilla_module():
assert isinstance(out["objectness"], Tensor)


def test_training_step():
# Setup the DataModule
data_path = "data-bin"
train_dataset = data_helper.get_dataset(data_root=data_path, mode="train")
val_dataset = data_helper.get_dataset(data_root=data_path, mode="val")
data_module = DetectionDataModule(train_dataset, val_dataset, batch_size=16)
# Load model
model = yolov5s()
model.train()
# Trainer
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, data_module)


def test_vanilla_coco_evaluator():
@pytest.mark.parametrize("version, map5095, map50", [("r4.0", 42.5, 65.3)])
def test_vanilla_coco_evaluator(version, map5095, map50):
# Acquire the images and labels from the coco128 dataset
val_dataloader = data_helper.get_dataloader(data_root="data-bin", mode="val")
coco = data_helper.get_coco_api_from_dataset(val_dataloader.dataset)
coco_evaluator = COCOEvaluator(coco)
# Load model
model = yolov5s(upstream_version="r4.0", pretrained=True)
model.eval()
model = yolov5s(upstream_version=version, pretrained=True)
model = model.eval()
for images, targets in val_dataloader:
preds = model(images)
coco_evaluator.update(preds, targets)

results = coco_evaluator.compute()
assert results["AP"] > 37.8
assert results["AP50"] > 59.6


def test_test_epoch_end():
# Acquire the annotation file
data_path = Path("data-bin")
coco128_dirname = "coco128"
data_helper.prepare_coco128(data_path, dirname=coco128_dirname)
annotation_file = data_path / coco128_dirname / "annotations" / "instances_train2017.json"

# Get dataloader to test
val_dataloader = data_helper.get_dataloader(data_root=data_path, mode="val")

# Load model
model = yolov5s(upstream_version="r4.0", pretrained=True, annotation_path=annotation_file)

# test step
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, test_dataloaders=val_dataloader)
# test epoch end
results = model.evaluator.compute()
assert results["AP"] > 37.8
assert results["AP50"] > 59.6
assert results["AP"] > map5095
assert results["AP50"] > map50


def test_predict_with_vanilla_model():
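
The evaluator test above also works as a standalone script. A minimal sketch, assuming the coco128 data has already been prepared under data-bin the way the test helpers expect:

import torch
from yolort.data import COCOEvaluator, _helper as data_helper
from yolort.models import yolov5s

# Build the dataloader and COCO ground-truth API from coco128,
# mirroring test_vanilla_coco_evaluator above.
val_dataloader = data_helper.get_dataloader(data_root="data-bin", mode="val")
coco_gt = data_helper.get_coco_api_from_dataset(val_dataloader.dataset)
coco_evaluator = COCOEvaluator(coco_gt)

model = yolov5s(upstream_version="r4.0", pretrained=True)
model = model.eval()

with torch.no_grad():
    for images, targets in val_dataloader:
        preds = model(images)
        coco_evaluator.update(preds, targets)

results = coco_evaluator.compute()
print(f"AP: {results['AP']:.1f}, AP50: {results['AP50']:.1f}")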
45 changes: 45 additions & 0 deletions test/test_trainer.py
@@ -0,0 +1,45 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from pathlib import Path

import pytest
import pytorch_lightning as pl
from yolort.data import DetectionDataModule, _helper as data_helper
from yolort.trainer import DefaultTask


def test_training_step():
# Setup the DataModule
data_path = "data-bin"
train_dataset = data_helper.get_dataset(data_root=data_path, mode="train")
val_dataset = data_helper.get_dataset(data_root=data_path, mode="val")
data_module = DetectionDataModule(train_dataset, val_dataset, batch_size=8)
# Load model
model = DefaultTask(arch="yolov5n")
model = model.train()
# Trainer
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, data_module)


@pytest.mark.parametrize("arch, version, map5095, map50", [("yolov5s", "r4.0", 42.5, 65.3)])
def test_test_epoch_end(arch, version, map5095, map50):
# Acquire the annotation file
data_path = Path("data-bin")
coco128_dirname = "coco128"
data_helper.prepare_coco128(data_path, dirname=coco128_dirname)
annotation_file = data_path / coco128_dirname / "annotations" / "instances_train2017.json"

# Get dataloader to test
val_dataloader = data_helper.get_dataloader(data_root=data_path, mode="val")

# Load model
model = DefaultTask(arch=arch, version=version, pretrained=True, annotation_path=annotation_file)

# test step
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, test_dataloaders=val_dataloader)
# test epoch end
results = model.evaluator.compute()
assert results["AP"] > map5095
assert results["AP50"] > map50
8 changes: 5 additions & 3 deletions yolort/models/__init__.py
@@ -1,12 +1,11 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from typing import Any

from torch import nn
from yolort.v5 import Conv
from yolort.v5.utils.activations import Hardswish, SiLU

from .yolo import YOLO
from .yolo_module import YOLOv5
from .yolov5 import YOLOv5

__all__ = [
"YOLO",
@@ -187,6 +186,9 @@ def yolov5ts(upstream_version: str = "r4.0", export_friendly: bool = False, **kw


def _export_module_friendly(model):
from yolort.v5 import Conv
from yolort.v5.utils.activations import Hardswish, SiLU

for m in model.modules():
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, Conv):
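
The lazy imports above pull Conv and the export-friendly Hardswish/SiLU activations into _export_module_friendly only when an export-friendly model is requested. The rest of the function body is collapsed in this diff, so the sketch below is only an assumption of what such an activation swap typically looks like, written under a hypothetical name:

from torch import nn
from yolort.v5 import Conv
from yolort.v5.utils.activations import Hardswish, SiLU


def _swap_activations_for_export(model: nn.Module) -> nn.Module:
    """Hypothetical helper: replace native activations inside Conv blocks
    with ONNX/TorchScript-friendly counterparts."""
    for m in model.modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, Conv):
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()  # export-friendly Hardswish
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()  # export-friendly SiLU
    return model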
12 changes: 0 additions & 12 deletions yolort/models/_utils.py
@@ -3,18 +3,6 @@

import torch
from torch import nn, Tensor
from torchvision.ops import box_iou


def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and
output prediction from model
"""
if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
return box_iou(target["boxes"], pred["boxes"]).diag().mean()


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
124 changes: 35 additions & 89 deletions yolort/models/yolo_module.py → yolort/models/yolov5.py
@@ -1,30 +1,56 @@
# Copyright (c) 2021, yolort team. All rights reserved.
import argparse

import warnings
from pathlib import PosixPath
from typing import Any, Dict, List, Callable, Optional, Tuple, Union
from typing import Any, Dict, List, Callable, Optional, Tuple

import torch
import torchvision
from pytorch_lightning import LightningModule
from torch import nn, Tensor
from torchvision.io import read_image
from yolort.data import COCOEvaluator, contains_any_tensor
from yolort.data import contains_any_tensor

from . import yolo
from ._utils import _evaluate_iou
from .transform import YOLOTransform, _get_shape_onnx
from .yolo import YOLO

__all__ = ["YOLOv5"]


class YOLOv5(LightningModule):
class YOLOv5(nn.Module):
"""
Wrapping the pre-processing (`LetterBox`) into the YOLO models.

Example:

Demo pipeline for YOLOv5 Inference.

.. code-block:: python
from yolort.models import YOLOv5

# Load the yolov5s version 6.0 models
arch = 'yolov5_darknet_pan_s_r60'
model = YOLOv5(arch=arch, pretrained=True, score_thresh=0.35)
model = model.eval()

# Perform inference on an image file
predictions = model.predict('bus.jpg')
# Perform inference on a list of image files
predictions2 = model.predict(['bus.jpg', 'zidane.jpg'])

We also support loading custom checkpoints trained with ultralytics/yolov5

.. code-block:: python
from yolort.models import YOLOv5

# Your trained checkpoint from ultralytics
checkpoint_path = 'yolov5n.pt'
model = YOLOv5.load_from_yolov5(checkpoint_path, score_thresh=0.35)
model = model.eval()

# Perform inference on an image file
predictions = model.predict('bus.jpg')

Args:
lr (float): The initial learning rate
arch (string): YOLO model architecture. Default: None
model (nn.Module): YOLO model. Default: None
num_classes (int): number of output classes of the model (not including
@@ -39,13 +65,10 @@ class YOLOv5(LightningModule):
be padded to a minimum rectangle to match `min_size / max_size` and each of its edges
is divisible by `size_divisible` if it is not specified. Default: None
fill_color (int): fill value for padding. Default: 114
annotation_path (Optional[Union[string, PosixPath]]): Path of the COCO annotation file
Default: None.
"""

def __init__(
self,
lr: float = 0.01,
arch: Optional[str] = None,
model: Optional[nn.Module] = None,
num_classes: int = 80,
@@ -55,13 +78,11 @@ def __init__(
size_divisible: int = 32,
fixed_shape: Optional[Tuple[int, int]] = None,
fill_color: int = 114,
annotation_path: Optional[Union[str, PosixPath]] = None,
**kwargs: Any,
) -> None:

super().__init__()

self.lr = lr
self.arch = arch
self.num_classes = num_classes

@@ -82,11 +103,6 @@ def __init__(
fill_color=fill_color,
)

# metrics
self.evaluator = None
if annotation_path is not None:
self.evaluator = COCOEvaluator(annotation_path, iou_type="bbox")

# used only on torchscript mode
self._has_warned = False

@@ -104,7 +120,7 @@ def _forward_impl(
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
like `scores`, `labels` and `boxes`.
"""
# get the original image sizes
original_image_sizes: List[Tuple[int, int]] = []
@@ -173,50 +189,6 @@ def forward(
"""
return self._forward_impl(inputs, targets)

def training_step(self, batch, batch_idx):
"""
The training step.
"""
loss_dict = self._forward_impl(*batch)
loss = sum(loss_dict.values())
self.log_dict(loss_dict, on_step=True, on_epoch=True, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
images, targets = batch
# fasterrcnn takes only images for eval() mode
preds = self._forward_impl(images)
iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, preds)]).mean()
outs = {"val_iou": iou}
self.log_dict(outs, on_step=True, on_epoch=True, prog_bar=True)
return outs

def validation_epoch_end(self, outs):
avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
self.log("avg_val_iou", avg_iou)

def test_step(self, batch, batch_idx):
"""
The test step.
"""
images, targets = batch
images = list(image.to(next(self.parameters()).device) for image in images)
preds = self._forward_impl(images)
results = self.evaluator(preds, targets)
# log step metric
self.log("eval_step", results, prog_bar=True, on_step=True)

def test_epoch_end(self, outputs):
return self.log("coco_eval", self.evaluator.compute())

def configure_optimizers(self):
return torch.optim.SGD(
self.model.parameters(),
lr=self.lr,
momentum=0.9,
weight_decay=5e-4,
)

@torch.no_grad()
def predict(self, x: Any, image_loader: Optional[Callable] = None) -> List[Dict[str, Tensor]]:
"""
@@ -278,32 +250,6 @@ def collate_images(self, samples: Any, image_loader: Callable) -> List[Tensor]:
"samples should be either a tensor, list of tensors, a image path or list of image paths."
)

@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--arch", default="yolov5_darknet_pan_s_r40", help="model architecture")
parser.add_argument(
"--pretrained",
action="store_true",
help="Use pre-trained models from the modelzoo",
)
parser.add_argument(
"--lr",
default=0.01,
type=float,
help="initial learning rate, 0.01 is the default value for training "
"on 8 gpus and 2 images_per_gpu",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--weight-decay",
default=5e-4,
type=float,
metavar="W",
help="weight decay (default: 5e-4)",
)
return parser

@classmethod
def load_from_yolov5(
cls,
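
With the Lightning hooks removed, YOLOv5 is now a plain nn.Module; the training step, evaluator plumbing, and optimizer setup move into the trainer task. The real yolort/trainer/lightning_task.py is not shown in this section, so the sketch below is only an illustration of how such a task could reuse the removed logic, under a hypothetical class name:

import torch
from pytorch_lightning import LightningModule
from yolort.models import YOLOv5


class IllustrativeTask(LightningModule):
    """Hypothetical stand-in for yolort.trainer.DefaultTask."""

    def __init__(self, arch: str = "yolov5_darknet_pan_s_r40", lr: float = 0.01, **kwargs):
        super().__init__()
        self.lr = lr
        self.model = YOLOv5(arch=arch, **kwargs)

    def training_step(self, batch, batch_idx):
        # Same loss aggregation the model performed when it was a LightningModule.
        images, targets = batch
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())
        self.log_dict(loss_dict, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Mirrors the optimizer removed from YOLOv5 above.
        return torch.optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )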
5 changes: 5 additions & 0 deletions yolort/trainer/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from .lightning_task import DefaultTask

__all__ = ["DefaultTask"]