Source code for utils.anchors

"""Extracts multiple region proposals on the input image"""

import torch
from torch import nn


[docs] class AnchorGenerator(nn.Module): """Extracts multiple region proposals on the input image""" def __init__( self, sizes: torch.Tensor, ratios: torch.Tensor, step: int = 1, ) -> None: """ :param sizes: Box scales (0,1] = S'/S. :type sizes: torch.Tensor :param ratios: Ratio of width to height of boxes (w/h). :type ratios: torch.Tensor :param step: Box per pixel. Defaults to 1. :type step: int, optional """ super().__init__() self.step = step self.sizes = nn.Parameter(sizes, requires_grad=False) self.ratios = nn.Parameter(ratios, requires_grad=False)
[docs] def __call__(self, X: torch.Tensor) -> torch.Tensor: """Generate anchor boxes with different shapes centered on each pixel. :param X: Feature map. :type X: torch.Tensor :return: Tensor with hypotheses. Shape: [anchor, 4]. Data: (xlu, ylu, xrd, yrd). :rtype: torch.Tensor """ if not hasattr(self, "anchors"): self._cal_anchors(X) return self.anchors
def _cal_anchors(self, X: torch.Tensor) -> None: in_height, in_width = X.shape[-2:] device, num_sizes, num_ratios = X.device, len(self.sizes), len(self.ratios) boxes_per_pixel = num_sizes * num_ratios # Offsets are required to move the anchor to the center of a pixel. Since # a pixel has height=1 and width=1, we choose to offset our centers by 0.5 offset_h, offset_w = 0.5, 0.5 steps_h = 1.0 / in_height # Scaled steps in y axis steps_w = 1.0 / in_width # Scaled steps in x axisr # Generate all center points for the anchor boxes center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing="ij") shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1) # Generate `boxes_per_pixel` number of heights and widths that are later # used to create anchor box corner coordinates (xmin, xmax, ymin, ymax) w = ( torch.cat([self.sizes * ratio for ratio in self.ratios]) * in_height / in_width ) h = ( torch.cat([self.sizes / ratio for ratio in self.ratios]) * in_width / in_height ) # Divide by 2 to get half height and half width anchor_manipulations = ( torch.stack((-w, -h, w, h)).T.repeat(in_height * in_width, 1) / 2 ) # Each center point will have `boxes_per_pixel` number of anchor boxes, so # generate a grid of all anchor box centers with `boxes_per_pixel` repeats out_grid = torch.stack( [shift_x, shift_y, shift_x, shift_y], dim=1 ).repeat_interleave(boxes_per_pixel, dim=0) self.anchors = out_grid + anchor_manipulations