"""Loss utilities for RoRD training, plus OmegaConf configuration helpers.

This file reconstructs two modules that arrived as one collapsed patch:
``losses.py`` (detection / description losses) and
``utils/config_loader.py`` (YAML config loading).
"""

import math
from pathlib import Path
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def _augment_homography_matrix(h_2x3: torch.Tensor) -> torch.Tensor:
    """Append the third row [0, 0, 1] to build a full 3x3 homography.

    Args:
        h_2x3: batch of affine transforms with shape (B, 2, 3).

    Returns:
        Tensor of shape (B, 3, 3) suitable for ``torch.inverse``.

    Raises:
        ValueError: if the input is not shaped (B, 2, 3).
    """
    if h_2x3.dim() != 3 or h_2x3.size(1) != 2 or h_2x3.size(2) != 3:
        raise ValueError("Expected homography with shape (B, 2, 3)")

    batch_size = h_2x3.size(0)
    bottom_row = torch.tensor([0.0, 0.0, 1.0], device=h_2x3.device, dtype=h_2x3.dtype)
    bottom_row = bottom_row.view(1, 1, 3).expand(batch_size, -1, -1)
    return torch.cat([h_2x3, bottom_row], dim=1)


def warp_feature_map(feature_map: torch.Tensor, h_inv: torch.Tensor) -> torch.Tensor:
    """Warp *feature_map* with the inverse affine homography *h_inv* (B, 2, 3).

    The transform is expressed in the normalized [-1, 1] coordinate frame
    used by ``F.affine_grid`` / ``F.grid_sample``.
    """
    grid = F.affine_grid(h_inv, feature_map.size(), align_corners=False)
    return F.grid_sample(feature_map, grid, align_corners=False)


def compute_detection_loss(
    det_original: torch.Tensor,
    det_rotated: torch.Tensor,
    h: torch.Tensor,
) -> torch.Tensor:
    """Binary cross-entropy + smooth L1 detection loss.

    The rotated detection map is warped back into the original frame with
    the inverse homography, then compared against the original map.

    Args:
        det_original: detection map of the original image, values in [0, 1]
            (assumed sigmoid outputs — BCE requires it; confirm upstream).
        det_rotated: detection map of the rotated image, same shape.
        h: affine homography (B, 2, 3) mapping original -> rotated.

    Returns:
        Scalar tensor: BCE + 0.1 * smooth-L1.
    """
    h_full = _augment_homography_matrix(h)
    # Only the top two rows are needed by affine_grid; the bottom row is [0, 0, 1].
    h_inv = torch.inverse(h_full)[:, :2, :]
    warped_det = warp_feature_map(det_rotated, h_inv)

    bce_loss = F.binary_cross_entropy(det_original, warped_det)
    smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det)
    return bce_loss + 0.1 * smooth_l1_loss


def compute_description_loss(
    desc_original: torch.Tensor,
    desc_rotated: torch.Tensor,
    h: torch.Tensor,
    margin: float = 1.0,
) -> torch.Tensor:
    """Triplet-style descriptor loss with Manhattan-aware sampling.

    Args:
        desc_original: descriptor map of the original image, (B, C, H, W).
        desc_rotated: descriptor map of the rotated image, (B, C, H, W).
        h: affine homography (B, 2, 3) mapping original -> rotated, in the
            normalized [-1, 1] grid_sample coordinate frame.
        margin: margin for the triplet term.

    Returns:
        Scalar loss: triplet + 0.1 * cosine-alignment + 0.01 * sparsity
        + 0.05 * sign-agreement.
    """
    batch_size = desc_original.size(0)
    device = desc_original.device
    # Nominal sampling budget; the actual number of points is 2 * grid_side
    # (one run per axis), i.e. 28 for num_samples = 200.
    num_samples = 200
    grid_side = int(math.sqrt(num_samples))

    h_coords = torch.linspace(-1, 1, grid_side, device=device)
    w_coords = torch.linspace(-1, 1, grid_side, device=device)

    # Axis-aligned ("Manhattan") sample points: (h, 0) then (0, w).
    manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
    manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
    manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1)
    manhattan_coords = manhattan_coords.unsqueeze(0).repeat(batch_size, 1, 1)

    # Anchor descriptors sampled from the original map -> (B, N, C).
    anchor = F.grid_sample(
        desc_original,
        manhattan_coords.unsqueeze(1),
        align_corners=False,
    ).squeeze(2).transpose(1, 2)

    # Map the sample points into the rotated image via the inverse homography.
    coords_hom = torch.cat(
        [manhattan_coords, torch.ones(batch_size, manhattan_coords.size(1), 1, device=device)],
        dim=2,
    )
    h_full = _augment_homography_matrix(h)
    h_inv = torch.inverse(h_full)
    # The bottom row of an affine homography is [0, 0, 1], so the homogeneous
    # w coordinate stays 1 and no perspective divide is needed.
    coords_transformed = (coords_hom @ h_inv.transpose(1, 2))[:, :, :2]

    positive = F.grid_sample(
        desc_rotated,
        coords_transformed.unsqueeze(1),
        align_corners=False,
    ).squeeze(2).transpose(1, 2)

    # Negatives: sample the rotated map at 90/180/270-degree rotations of the
    # anchor coordinates. Rotation matrices are built with math.cos/math.sin,
    # avoiding the per-angle tensor allocations of the original loop (which
    # also pointlessly iterated over angle == 0 and skipped it).
    negative_list = []
    if manhattan_coords.size(1) > 0:
        for angle in (90, 180, 270):
            theta = math.radians(angle)
            rot = torch.tensor(
                [[math.cos(theta), -math.sin(theta)],
                 [math.sin(theta), math.cos(theta)]],
                device=device,
                dtype=manhattan_coords.dtype,
            )
            negative_list.append(manhattan_coords @ rot.T)

    if negative_list:
        neg_coords = torch.stack(negative_list, dim=1).reshape(batch_size, -1, 2)
        negative_candidates = F.grid_sample(
            desc_rotated,
            neg_coords.unsqueeze(1),
            align_corners=False,
        ).squeeze(2).transpose(1, 2)

        # Hard-negative mining: for every anchor keep the k candidates with the
        # smallest Manhattan distance (closest == hardest) and average them.
        anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
        negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
        manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)

        k = max(anchor.size(1) // 2, 1)
        hard_indices = torch.topk(manhattan_dist, k=k, largest=False)[1]
        idx_expand = hard_indices.unsqueeze(-1).expand(-1, -1, -1, negative_candidates.size(2))
        negative = torch.gather(
            negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1), 2, idx_expand
        ).mean(dim=2)
    else:
        negative = torch.zeros_like(anchor)

    triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')
    geometric_triplet = triplet_loss(anchor, positive, negative)

    # Cosine-alignment term, vectorized over sample points (was a Python loop
    # issuing one normalize/sum per point; same value, single pass).
    anchor_norm = F.normalize(anchor, p=2, dim=2)
    positive_norm = F.normalize(positive, p=2, dim=2)
    manhattan_loss = torch.mean(1 - torch.sum(anchor_norm * positive_norm, dim=2))

    sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
    # NOTE(review): torch.sign has zero gradient, so this term contributes no
    # gradient signal — confirm a straight-through variant was not intended.
    binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))

    return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss


# ---------------------------------------------------------------------------
# utils/config_loader.py — configuration loading utilities using OmegaConf.
# ---------------------------------------------------------------------------

def load_config(config_path: Union[str, Path]) -> "DictConfig":
    """Load a YAML configuration file into an OmegaConf ``DictConfig``.

    Raises:
        FileNotFoundError: if *config_path* does not exist.
    """
    # Imported lazily so the loss utilities above remain importable in
    # environments where omegaconf is not installed.
    from omegaconf import OmegaConf

    path = Path(config_path)
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")
    return OmegaConf.load(path)


def to_absolute_path(path_str: str, base_dir: Union[str, Path]) -> Path:
    """Resolve *path_str* against *base_dir* unless it is already absolute.

    ``~`` is expanded first; the result is fully resolved via ``Path.resolve``.
    """
    path = Path(path_str).expanduser()
    if path.is_absolute():
        return path.resolve()
    return (Path(base_dir) / path).resolve()