Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing build_targets in SetCriterion #143

Merged
merged 16 commits into from
Sep 9, 2021
Prev Previous commit
Fixing jit annotations for PT1.7
  • Loading branch information
zhiqwang committed Sep 9, 2021
commit 609399fd6c84d9b3fc818bb7cc34a14dbc3587a3
10 changes: 6 additions & 4 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
@@ -121,7 +121,9 @@ def __init__(

# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
# positive, negative BCE targets
self.cp, self.cn = det_utils.smooth_binary_cross_entropy(eps=label_smoothing)
smooth_bce = det_utils.smooth_binary_cross_entropy(eps=label_smoothing)
self.smooth_pos = smooth_bce[0]
self.smooth_neg = smooth_bce[1]

# Parameters for training
self.gr = 1.0
@@ -177,7 +179,7 @@ def __call__(
loss_box += (1.0 - iou).mean() # iou loss

# Objectness
score_iou = iou.detach().clamp(0).type(target_obj.dtype)
score_iou = iou.detach().clamp(0).to(dtype=target_obj.dtype)
if self.sort_obj_iou:
sort_id = torch.argsort(score_iou)
b, a, gj, gi = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id]
@@ -186,8 +188,8 @@ def __call__(

# Classification
if self.num_classes > 1: # cls loss (only if multiple classes)
t = torch.full_like(pred_logits_subset[:, 5:], self.cn, device=device) # targets
t[torch.arange(num_targets), target_cls[i]] = self.cp
t = torch.full_like(pred_logits_subset[:, 5:], self.smooth_neg, device=device) # targets
t[torch.arange(num_targets), target_cls[i]] = self.smooth_pos
loss_cls += F.binary_cross_entropy_with_logits(
pred_logits_subset[:, 5:], t, pos_weight=pos_weight_cls)