139 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			139 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Loss utilities for RoRD training."""
 | |
| from __future__ import annotations
 | |
| 
 | |
| import math
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| 
 | |
| 
 | |
def _augment_homography_matrix(h_2x3: torch.Tensor) -> torch.Tensor:
    """Promote a batch of 2x3 matrices to 3x3 by appending the row [0, 0, 1].

    Args:
        h_2x3: Tensor of shape (B, 2, 3).

    Returns:
        Tensor of shape (B, 3, 3) whose first two rows are ``h_2x3`` and whose
        last row is [0, 0, 1] for every batch element. The appended row uses
        the input's dtype and device.

    Raises:
        ValueError: If ``h_2x3`` does not have shape (B, 2, 3).
    """
    if h_2x3.dim() != 3 or tuple(h_2x3.shape[1:]) != (2, 3):
        raise ValueError("Expected homography with shape (B, 2, 3)")

    # new_tensor inherits dtype and device from the input tensor.
    last_row = h_2x3.new_tensor([0.0, 0.0, 1.0]).reshape(1, 1, 3)
    # expand broadcasts the single row across the batch without copying.
    last_row = last_row.expand(h_2x3.size(0), 1, 3)
    return torch.cat((h_2x3, last_row), dim=1)
 | |
| 
 | |
| 
 | |
def warp_feature_map(feature_map: torch.Tensor, h_inv: torch.Tensor) -> torch.Tensor:
    """Resample a feature map through a batch of affine transforms.

    Args:
        feature_map: Tensor handed to ``F.grid_sample``; its size is also used
            as the output size of the sampling grid.
        h_inv: Affine matrices passed as ``theta`` to ``F.affine_grid``
            (per the PyTorch docs, shape (B, 2, 3) for a 4-D feature map).
            They act on normalized [-1, 1] coordinates, not pixel coordinates.

    Returns:
        The resampled tensor, with the same size as ``feature_map``.
    """
    # Both calls must agree on align_corners for the grid to line up.
    sampling_grid = F.affine_grid(h_inv, feature_map.size(), align_corners=False)
    return F.grid_sample(feature_map, sampling_grid, align_corners=False)
 | |
| 
 | |
| 
 | |
def compute_detection_loss(
    det_original: torch.Tensor,
    det_rotated: torch.Tensor,
    h: torch.Tensor,
) -> torch.Tensor:
    """Binary cross-entropy plus weighted smooth-L1 detection loss.

    The rotated detection map is warped with the inverse of ``h`` and then
    compared against the original detection map.

    Args:
        det_original: Detection map used as the *input* of both loss terms.
        det_rotated: Detection map that is warped and used as the *target*.
        h: Transforms of shape (B, 2, 3); other shapes raise ``ValueError``
            in ``_augment_homography_matrix``.

    Returns:
        Scalar tensor ``bce + 0.1 * smooth_l1``.
    """
    # Invert the full 3x3 matrix, then keep the top two rows for affine_grid.
    inverse_full = torch.inverse(_augment_homography_matrix(h))
    affine_inverse = inverse_full[:, :2, :]
    aligned_det = warp_feature_map(det_rotated, affine_inverse)

    # NOTE(review): F.binary_cross_entropy expects values in [0, 1];
    # presumably both maps are sigmoid outputs - confirm against the model.
    bce_term = F.binary_cross_entropy(det_original, aligned_det)
    smooth_l1_term = F.smooth_l1_loss(det_original, aligned_det)
    return bce_term + 0.1 * smooth_l1_term
 | |
| 
 | |
| 
 | |
def _sample_descriptors(desc_map: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """Bilinearly sample ``desc_map`` (B, C, H, W) at ``coords`` (B, N, 2) -> (B, N, C)."""
    sampled = F.grid_sample(desc_map, coords.unsqueeze(1), align_corners=False)
    return sampled.squeeze(2).transpose(1, 2)


def compute_description_loss(
    desc_original: torch.Tensor,
    desc_rotated: torch.Tensor,
    h: torch.Tensor,
    margin: float = 1.0,
) -> torch.Tensor:
    """Triplet-style descriptor loss with Manhattan-aware sampling.

    Sample points lie on the two axes through the image centre, expressed in
    ``grid_sample``'s normalized [-1, 1] coordinates. With the fixed budget of
    200, ``int(sqrt(200)) = 14`` points are placed on each axis (28 in total).

    * Anchors are ``desc_original`` sampled at those points.
    * Positives are ``desc_rotated`` sampled at the points mapped through the
      inverse of ``h``.
    * Negatives are drawn from ``desc_rotated`` at the points rotated by 90,
      180 and 270 degrees. For each anchor, the ``N // 2`` candidates with the
      smallest L1 distance are averaged into one hard negative.

    Args:
        desc_original: Descriptor map of shape (B, C, H, W).
        desc_rotated: Descriptor map of the transformed image; assumed to
            share dtype and device with ``desc_original``.
        h: Transforms of shape (B, 2, 3). The inverse is cast to the
            descriptors' dtype and device before use.
        margin: Margin of the L1 triplet loss.

    Returns:
        Scalar tensor ``triplet + 0.1 * cosine + 0.01 * sparsity + 0.05 * binary``.

    Raises:
        ValueError: If ``desc_original`` is not 4-D, or if ``h`` is not
            (B, 2, 3) (raised by ``_augment_homography_matrix``).
    """
    if desc_original.dim() != 4:
        raise ValueError("Expected descriptors with shape (B, C, H, W)")

    batch_size = desc_original.size(0)
    num_samples = 200
    grid_side = int(math.sqrt(num_samples))

    # Every helper tensor is created with the descriptors' dtype and device:
    # grid_sample requires the grid and the input to share a dtype, so
    # default float32 coordinates would fail for float64 or half descriptors.
    axis = torch.linspace(
        -1, 1, grid_side, device=desc_original.device, dtype=desc_original.dtype
    )
    zeros = torch.zeros_like(axis)
    # First grid_side points vary the first coordinate, the rest the second.
    manhattan_coords = torch.stack(
        [torch.cat([axis, zeros]), torch.cat([zeros, axis])], dim=1
    )
    manhattan_coords = manhattan_coords.unsqueeze(0).repeat(batch_size, 1, 1)
    num_points = manhattan_coords.size(1)

    anchor = _sample_descriptors(desc_original, manhattan_coords)

    # Map the sample points through the inverse transform. The augmented
    # matrix has last row [0, 0, 1], so its inverse does too and no
    # perspective divide is needed. Invert in h's own precision, then cast.
    h_inv = torch.inverse(_augment_homography_matrix(h)).to(desc_original)
    coords_hom = torch.cat(
        [manhattan_coords, manhattan_coords.new_ones(batch_size, num_points, 1)],
        dim=2,
    )
    coords_transformed = (coords_hom @ h_inv.transpose(1, 2))[:, :, :2]
    positive = _sample_descriptors(desc_rotated, coords_transformed)

    # Negative locations: the sample points rotated about the centre by every
    # non-zero multiple of 90 degrees.
    rotated_coords = []
    for angle in (90, 180, 270):
        theta = angle * math.pi / 180.0
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        rot = manhattan_coords.new_tensor([[cos_t, -sin_t], [sin_t, cos_t]])
        rotated_coords.append(manhattan_coords @ rot.T)
    neg_coords = torch.stack(rotated_coords, dim=1).reshape(batch_size, -1, 2)
    negative_candidates = _sample_descriptors(desc_rotated, neg_coords)  # (B, 3N, C)

    # Hard-negative mining: L1 distance from every anchor to every candidate,
    # keep the k closest candidates per anchor and average them.
    candidates_per_anchor = negative_candidates.unsqueeze(1).expand(
        -1, num_points, -1, -1
    )  # (B, N, 3N, C)
    manhattan_dist = (anchor.unsqueeze(2) - candidates_per_anchor).abs().sum(dim=3)
    k = max(num_points // 2, 1)
    hard_indices = torch.topk(manhattan_dist, k=k, largest=False)[1]  # (B, N, k)
    idx_expand = hard_indices.unsqueeze(-1).expand(
        -1, -1, -1, negative_candidates.size(2)
    )
    negative = torch.gather(candidates_per_anchor, 2, idx_expand).mean(dim=2)

    geometric_triplet = F.triplet_margin_loss(
        anchor, positive, negative, margin=margin, p=1, reduction="mean"
    )

    # Mean (1 - cosine similarity) between anchors and positives, computed
    # for all sample points at once.
    anchor_unit = F.normalize(anchor, p=2, dim=2)
    positive_unit = F.normalize(positive, p=2, dim=2)
    manhattan_loss = (1 - (anchor_unit * positive_unit).sum(dim=2)).mean()

    sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
    # NOTE(review): torch.sign has zero gradient, so this term changes the
    # reported loss value but presumably contributes nothing to backprop.
    binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))

    return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
 | 
