taishi-addtodolist #1
138
losses.py
Normal file
138
losses.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Loss utilities for RoRD training."""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _augment_homography_matrix(h_2x3: torch.Tensor) -> torch.Tensor:
|
||||
"""Append the third row [0, 0, 1] to build a full 3x3 homography."""
|
||||
if h_2x3.dim() != 3 or h_2x3.size(1) != 2 or h_2x3.size(2) != 3:
|
||||
raise ValueError("Expected homography with shape (B, 2, 3)")
|
||||
|
||||
batch_size = h_2x3.size(0)
|
||||
device = h_2x3.device
|
||||
bottom_row = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=h_2x3.dtype)
|
||||
bottom_row = bottom_row.view(1, 1, 3).expand(batch_size, -1, -1)
|
||||
return torch.cat([h_2x3, bottom_row], dim=1)
|
||||
|
||||
|
||||
def warp_feature_map(feature_map: torch.Tensor, h_inv: torch.Tensor) -> torch.Tensor:
|
||||
"""Warp feature map according to inverse homography."""
|
||||
return F.grid_sample(
|
||||
feature_map,
|
||||
F.affine_grid(h_inv, feature_map.size(), align_corners=False),
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
|
||||
def compute_detection_loss(
|
||||
det_original: torch.Tensor,
|
||||
det_rotated: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Binary cross-entropy + smooth L1 detection loss."""
|
||||
h_full = _augment_homography_matrix(h)
|
||||
h_inv = torch.inverse(h_full)[:, :2, :]
|
||||
warped_det = warp_feature_map(det_rotated, h_inv)
|
||||
|
||||
bce_loss = F.binary_cross_entropy(det_original, warped_det)
|
||||
smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det)
|
||||
return bce_loss + 0.1 * smooth_l1_loss
|
||||
|
||||
|
||||
def compute_description_loss(
|
||||
desc_original: torch.Tensor,
|
||||
desc_rotated: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
margin: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Triplet-style descriptor loss with Manhattan-aware sampling."""
|
||||
batch_size, channels, height, width = desc_original.size()
|
||||
num_samples = 200
|
||||
|
||||
grid_side = int(math.sqrt(num_samples))
|
||||
h_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
|
||||
w_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
|
||||
|
||||
manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
|
||||
manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
|
||||
manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1)
|
||||
manhattan_coords = manhattan_coords.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
|
||||
anchor = F.grid_sample(
|
||||
desc_original,
|
||||
manhattan_coords.unsqueeze(1),
|
||||
align_corners=False,
|
||||
).squeeze(2).transpose(1, 2)
|
||||
|
||||
coords_hom = torch.cat(
|
||||
[manhattan_coords, torch.ones(batch_size, manhattan_coords.size(1), 1, device=desc_original.device)],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
h_full = _augment_homography_matrix(h)
|
||||
h_inv = torch.inverse(h_full)
|
||||
coords_transformed = (coords_hom @ h_inv.transpose(1, 2))[:, :, :2]
|
||||
|
||||
positive = F.grid_sample(
|
||||
desc_rotated,
|
||||
coords_transformed.unsqueeze(1),
|
||||
align_corners=False,
|
||||
).squeeze(2).transpose(1, 2)
|
||||
|
||||
negative_list = []
|
||||
if manhattan_coords.size(1) > 0:
|
||||
angles = [0, 90, 180, 270]
|
||||
for angle in angles:
|
||||
if angle == 0:
|
||||
continue
|
||||
theta = torch.tensor(angle * math.pi / 180.0, device=desc_original.device)
|
||||
cos_t = torch.cos(theta)
|
||||
sin_t = torch.sin(theta)
|
||||
rot = torch.stack(
|
||||
[
|
||||
torch.stack([cos_t, -sin_t]),
|
||||
torch.stack([sin_t, cos_t]),
|
||||
]
|
||||
)
|
||||
rotated_coords = manhattan_coords @ rot.T
|
||||
negative_list.append(rotated_coords)
|
||||
|
||||
if negative_list:
|
||||
neg_coords = torch.stack(negative_list, dim=1).reshape(batch_size, -1, 2)
|
||||
negative_candidates = F.grid_sample(
|
||||
desc_rotated,
|
||||
neg_coords.unsqueeze(1),
|
||||
align_corners=False,
|
||||
).squeeze(2).transpose(1, 2)
|
||||
|
||||
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
|
||||
negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
|
||||
manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
|
||||
|
||||
k = max(anchor.size(1) // 2, 1)
|
||||
hard_indices = torch.topk(manhattan_dist, k=k, largest=False)[1]
|
||||
idx_expand = hard_indices.unsqueeze(-1).expand(-1, -1, -1, negative_candidates.size(2))
|
||||
negative = torch.gather(negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1), 2, idx_expand)
|
||||
negative = negative.mean(dim=2)
|
||||
else:
|
||||
negative = torch.zeros_like(anchor)
|
||||
|
||||
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')
|
||||
geometric_triplet = triplet_loss(anchor, positive, negative)
|
||||
|
||||
manhattan_loss = 0.0
|
||||
for i in range(anchor.size(1)):
|
||||
anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
|
||||
positive_norm = F.normalize(positive[:, i], p=2, dim=1)
|
||||
cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
|
||||
manhattan_loss += torch.mean(1 - cos_sim)
|
||||
|
||||
manhattan_loss = manhattan_loss / max(anchor.size(1), 1)
|
||||
sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
|
||||
binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
|
||||
|
||||
return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
|
||||
23
utils/config_loader.py
Normal file
23
utils/config_loader.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Configuration loading utilities using OmegaConf."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> DictConfig:
|
||||
"""Load a YAML configuration file into a DictConfig."""
|
||||
path = Path(config_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {path}")
|
||||
return OmegaConf.load(path)
|
||||
|
||||
|
||||
def to_absolute_path(path_str: str, base_dir: Union[str, Path]) -> Path:
|
||||
"""Resolve a possibly relative path against the configuration file directory."""
|
||||
path = Path(path_str).expanduser()
|
||||
if path.is_absolute():
|
||||
return path.resolve()
|
||||
return (Path(base_dir) / path).resolve()
|
||||
Reference in New Issue
Block a user