Source code for utils.roi

"""Label anchor boxes using ground-truth bounding boxes"""

import torch
from tqdm import tqdm
import utils.box as box


[docs] class RoI: """Label anchor boxes using ground-truth bounding boxes""" def __init__(self, iou_threshold=0.5) -> None: """ :param iou_threshold: Minimum acceptable iou. Defaults to 0.5. :type iou_threshold: float, optional """ self.iou_threshold = iou_threshold
[docs] def __call__( self, anchors: torch.Tensor, labels: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Label anchor boxes using ground-truth bounding boxes :param anchors: Shape [anchor, 4] :type anchors: torch.Tensor :param labels: Shape [batch, gt_boxes, 5] One label contains (class, luw, luh, rdw, rdh) :type labels: torch.Tensor :return: List of 3 tensors: 1. Ground truth offsets for each box. Shape [batch, anchor, 4]. 2. Box mask. Shape [batch, anchor, 4]. (0)*4 for background, (1)*4 for object. 3. Class of each box (0 - background). Shape [batch, anchor]. :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor] """ batch_size = labels.shape[0] device, num_anchors = anchors.device, anchors.shape[0] batch_offset, batch_mask, batch_class_labels = [], [], [] for i in range(batch_size): label = labels[i, :, :] anchors_bbox_map = self._assign_anchor_to_box(label[:, 1:], anchors) bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(1, 4) # Initialize class labels and assigned bounding box coordinates with zeros class_labels = torch.zeros(num_anchors, dtype=torch.long, device=device) assigned_bb = torch.zeros( (num_anchors, 4), dtype=torch.float32, device=device ) # Label classes of anchor boxes using their assigned ground-truth # bounding boxes. If an anchor box is not assigned any, we label its # class as background (the value remains zero) indices_true = torch.nonzero(anchors_bbox_map >= 0) bb_idx = anchors_bbox_map[indices_true] class_labels[indices_true] = label[bb_idx, 0].long() + 1 assigned_bb[indices_true] = label[bb_idx, 1:] # Offset transformation offset = box.offset_boxes(anchors, assigned_bb) * bbox_mask batch_offset.append(offset) batch_mask.append(bbox_mask) batch_class_labels.append(class_labels) bbox_offset = torch.stack(batch_offset) bbox_mask = torch.stack(batch_mask) class_labels = torch.stack(batch_class_labels) return bbox_offset, bbox_mask, class_labels
def _assign_anchor_to_box( self, ground_truth: torch.Tensor, anchors: torch.Tensor ) -> torch.Tensor: """Assign closest ground-truth bounding boxes to anchor boxes see https://d2l.ai/chapter_computer-vision/anchor.html#assigning-ground-truth-bounding-boxes-to-anchor-boxes :param ground_truth: The ground-truth bounding boxes [gt_box, 4] - ulw, ulh, drw, drh :type ground_truth: torch.Tensor :param anchors: Anchors boxes [anchor, 4] - ulw, ulh, drw, drh :type anchors: torch.Tensor :return: Tensor with ground truth box indices [anchor] :rtype: torch.Tensor """ num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0] # The Jaccard index measures the similarity between two sets [num_anchors, num_gt_box] jaccard = box.box_iou(anchors, ground_truth) # Initialize the tensor to hold the assigned ground-truth bounding box for each anchor anchors_box_map = torch.full( (num_anchors,), -1, dtype=torch.long, device=anchors.device ) # Assign ground-truth bounding boxes according to the threshold max_ious, indices = torch.max(jaccard, dim=1) # Indexes of non-empty boxes mask = max_ious >= self.iou_threshold anc_i = torch.nonzero(mask).reshape(-1) if len(anc_i) == 0: tqdm.write("[WARN]: There is no suitable anchor") box_j = indices[mask] # Each anchor is assigned a gt_box with the highest iou if it is greater than the threshold anchors_box_map[anc_i] = box_j # For each gt_box we assign an anchor with maximum iou col_discard = torch.full((num_anchors,), -1) row_discard = torch.full((num_gt_boxes,), -1) for _ in range(num_gt_boxes): max_idx = torch.argmax(jaccard) # Find the largest IoU box_idx = (max_idx % num_gt_boxes).long() anc_idx = (max_idx / num_gt_boxes).long() anchors_box_map[anc_idx] = box_idx jaccard[:, box_idx] = col_discard jaccard[anc_idx, :] = row_discard return anchors_box_map