From 8c9926c815f1789f845157f97566c61c45c91156 Mon Sep 17 00:00:00 2001
From: Jiao77
Date: Thu, 25 Sep 2025 20:20:24 +0800
Subject: [PATCH 1/3] complete code structure update

---
 config.py                |  51 +++----
 configs/base_config.yaml |  25 ++++
 data/__init__.py         |   1 +
 data/ic_dataset.py       |  97 +++++++++++++-
 docs/feature_work.md     | 166 +++++++++++++++++++++++
 evaluate.py              |  48 ++++---
 match.py                 |  55 +++++---
 pyproject.toml           |   1 +
 train.py                 | 279 ++++++++------------------------------
 uv.lock                  |  47 +++++++
 10 files changed, 480 insertions(+), 290 deletions(-)
 create mode 100644 configs/base_config.yaml
 create mode 100644 docs/feature_work.md

diff --git a/config.py b/config.py
index 0f43a0a..621b64b 100644
--- a/config.py
+++ b/config.py
@@ -1,29 +1,34 @@
-# config.py
+"""Legacy config shim loading values from YAML."""
+from __future__ import annotations
+
+from pathlib import Path
+
+from omegaconf import OmegaConf
+
+
+_BASE_CONFIG_PATH = Path(__file__).resolve().parent / "configs" / "base_config.yaml"
+_CFG = OmegaConf.load(_BASE_CONFIG_PATH)

 # --- Training parameters ---
-LEARNING_RATE = 5e-5  # lower learning rate for more stable training
-BATCH_SIZE = 8        # larger batch size for better throughput
-NUM_EPOCHS = 50       # more training epochs
-PATCH_SIZE = 256
-# (Tuned) scale-jitter range during training - narrowed for stability
-SCALE_JITTER_RANGE = (0.8, 1.2)
+LEARNING_RATE = float(_CFG.training.learning_rate)
+BATCH_SIZE = int(_CFG.training.batch_size)
+NUM_EPOCHS = int(_CFG.training.num_epochs)
+PATCH_SIZE = int(_CFG.training.patch_size)
+SCALE_JITTER_RANGE = tuple(float(x) for x in _CFG.training.scale_jitter_range)

 # --- Matching & evaluation parameters ---
-KEYPOINT_THRESHOLD = 0.5
-RANSAC_REPROJ_THRESHOLD = 5.0
-MIN_INLIERS = 15
-IOU_THRESHOLD = 0.5
-# (New) image-pyramid scales for template matching at inference
-PYRAMID_SCALES = [0.75, 1.0, 1.5]
-# (New) sliding-window parameters for large layouts at inference
-INFERENCE_WINDOW_SIZE = 1024
-INFERENCE_STRIDE = 768  # smaller than INFERENCE_WINDOW_SIZE to guarantee overlap
+KEYPOINT_THRESHOLD = float(_CFG.matching.keypoint_threshold)
+RANSAC_REPROJ_THRESHOLD = float(_CFG.matching.ransac_reproj_threshold)
+MIN_INLIERS = int(_CFG.matching.min_inliers)
+PYRAMID_SCALES = [float(s) for s in _CFG.matching.pyramid_scales]
+INFERENCE_WINDOW_SIZE = int(_CFG.matching.inference_window_size)
+INFERENCE_STRIDE = int(_CFG.matching.inference_stride)
+IOU_THRESHOLD = float(_CFG.evaluation.iou_threshold)

 # --- File paths ---
-# (Paths unchanged; adjust them for your environment)
-LAYOUT_DIR = 'path/to/layouts'
-SAVE_DIR = 'path/to/save'
-VAL_IMG_DIR = 'path/to/val/images'
-VAL_ANN_DIR = 'path/to/val/annotations'
-TEMPLATE_DIR = 'path/to/templates'
-MODEL_PATH = 'path/to/save/model_final.pth'
\ No newline at end of file
+def _resolve_path(value: str) -> str:
+    """Resolve a config path against the configs/ directory unless it is absolute."""
+    path = Path(value)
+    return str(value) if path.is_absolute() else str((_BASE_CONFIG_PATH.parent / path).resolve())
+
+
+LAYOUT_DIR = _resolve_path(_CFG.paths.layout_dir)
+SAVE_DIR = _resolve_path(_CFG.paths.save_dir)
+VAL_IMG_DIR = _resolve_path(_CFG.paths.val_img_dir)
+VAL_ANN_DIR = _resolve_path(_CFG.paths.val_ann_dir)
+TEMPLATE_DIR = _resolve_path(_CFG.paths.template_dir)
+MODEL_PATH = _resolve_path(_CFG.paths.model_path)
\ No newline at end of file
diff --git a/configs/base_config.yaml b/configs/base_config.yaml
new file mode 100644
index
0000000..70a6512
--- /dev/null
+++ b/configs/base_config.yaml
@@ -0,0 +1,25 @@
+training:
+  learning_rate: 5.0e-5
+  batch_size: 8
+  num_epochs: 50
+  patch_size: 256
+  scale_jitter_range: [0.8, 1.2]
+
+matching:
+  keypoint_threshold: 0.5
+  ransac_reproj_threshold: 5.0
+  min_inliers: 15
+  pyramid_scales: [0.75, 1.0, 1.5]
+  inference_window_size: 1024
+  inference_stride: 768
+
+evaluation:
+  iou_threshold: 0.5
+
+paths:
+  layout_dir: "path/to/layouts"
+  save_dir: "path/to/save"
+  val_img_dir: "path/to/val/images"
+  val_ann_dir: "path/to/val/annotations"
+  template_dir: "path/to/templates"
+  model_path: "path/to/save/model_final.pth"
diff --git a/data/__init__.py b/data/__init__.py
index e69de29..a2821b0 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -0,0 +1 @@
+from .ic_dataset import ICLayoutDataset, ICLayoutTrainingDataset
diff --git a/data/ic_dataset.py b/data/ic_dataset.py
index 155a6c8..049ab54 100644
--- a/data/ic_dataset.py
+++ b/data/ic_dataset.py
@@ -1,7 +1,12 @@
 import os
+import json
+from typing import Tuple
+
+import cv2
+import numpy as np
+import torch
 from PIL import Image
 from torch.utils.data import Dataset
-import json

 class ICLayoutDataset(Dataset):
     def __init__(self, image_dir, annotation_dir=None, transform=None):
@@ -53,4 +58,92 @@
                 with open(ann_path, 'r') as f:
                     annotation = json.load(f)
-        return image, annotation
\ No newline at end of file
+        return image, annotation
+
+
+class ICLayoutTrainingDataset(Dataset):
+    """Self-supervised IC-layout training dataset with augmentation and geometric alignment labels."""
+
+    def __init__(
+        self,
+        image_dir: str,
+        patch_size: int = 256,
+        transform=None,
+        scale_range: Tuple[float, float] = (1.0, 1.0),
+    ) -> None:
+        self.image_dir = image_dir
+        self.image_paths = [
+            os.path.join(image_dir, f)
+            for f in os.listdir(image_dir)
+            if f.endswith('.png')
+        ]
+        self.patch_size = patch_size
+        self.transform = transform
+        self.scale_range = scale_range
+
+    def __len__(self) -> int:
+        return len(self.image_paths)
+
+    def __getitem__(self, index: int):
+        img_path = self.image_paths[index]
+        image = Image.open(img_path).convert('L')
+        width, height = image.size
+
+        # Random scale jitter
+        scale = float(np.random.uniform(self.scale_range[0], self.scale_range[1]))
+        crop_size = int(self.patch_size / max(scale, 1e-6))
+        crop_size = min(crop_size, width, height)
+
+        if crop_size <= 0:
+            raise ValueError("crop_size must be positive; check scale_range configuration")
+
+        x = np.random.randint(0, max(width - crop_size + 1, 1))
+        y = np.random.randint(0, max(height - crop_size + 1, 1))
+        patch = image.crop((x, y, x + crop_size, y + crop_size))
+        patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
+
+        # Brightness / contrast augmentation
+        if np.random.random() < 0.5:
+            brightness_factor = np.random.uniform(0.8, 1.2)
+            patch = patch.point(lambda px: int(np.clip(px * brightness_factor, 0, 255)))
+
+        if np.random.random() < 0.5:
+            contrast_factor = np.random.uniform(0.8, 1.2)
+            patch = patch.point(lambda px: int(np.clip(((px - 128) * contrast_factor) + 128, 0, 255)))
+
+        # Additive Gaussian noise
+        if np.random.random() < 0.3:
+            patch_np = np.array(patch, dtype=np.float32)
+            noise = np.random.normal(0, 5, patch_np.shape)
+            patch_np = np.clip(patch_np + noise, 0, 255)
+            patch = Image.fromarray(patch_np.astype(np.uint8))
+
+        patch_np_uint8 = np.array(patch)
+
+        # Random rotation and mirroring (8 discrete transforms)
+        theta_deg = int(np.random.choice([0, 90, 180, 270]))
+        is_mirrored = bool(np.random.choice([True, False]))
+        center_x, center_y = self.patch_size / 2.0, self.patch_size / 2.0
+        rotation_matrix = cv2.getRotationMatrix2D((center_x, center_y), theta_deg, 1.0)
+
+        if is_mirrored:
+            translate_to_origin = np.array([[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]])
+            mirror = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+            translate_back = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
+            mirror_matrix = translate_back @ mirror @ translate_to_origin
+            rotation_matrix_h = np.vstack([rotation_matrix, [0, 0, 1]])
+            homography = (rotation_matrix_h @ mirror_matrix).astype(np.float32)
+        else:
+            homography = np.vstack([rotation_matrix, [0, 0, 1]]).astype(np.float32)
+
+        transformed_patch_np = cv2.warpPerspective(patch_np_uint8, homography, (self.patch_size, self.patch_size))
+        transformed_patch = Image.fromarray(transformed_patch_np)
+
+        if self.transform:
+            patch_tensor = self.transform(patch)
+            transformed_tensor = self.transform(transformed_patch)
+        else:
+            patch_tensor = torch.from_numpy(np.array(patch)).float().unsqueeze(0) / 255.0
+            transformed_tensor = torch.from_numpy(np.array(transformed_patch)).float().unsqueeze(0) / 255.0
+
+        # Ground-truth homography in pixel coordinates, truncated to 2x3
+        H_tensor = torch.from_numpy(homography[:2, :]).float()
+        return patch_tensor, transformed_tensor, H_tensor
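+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative usage only; "path/to/layouts" is a
+    # placeholder for any directory of PNG layout images, mirroring
+    # paths.layout_dir in configs/base_config.yaml).
+    dataset = ICLayoutTrainingDataset("path/to/layouts", patch_size=256)
+    patch, warped, h = dataset[0]
+    # Expected shapes: (1, 256, 256), (1, 256, 256) and (2, 3).
+    print(patch.shape, warped.shape, h.shape)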
diff --git a/docs/feature_work.md b/docs/feature_work.md
new file mode 100644
index 0000000..ce3a5e9
--- /dev/null
+++ b/docs/feature_work.md
@@ -0,0 +1,166 @@
+# Future Work
+
+This document consolidates the RoRD project's optimization to-do list and its training requirements, to guide future development and experiment planning.
+
+---
+
+## RoRD Optimization To-Do List
+
+This list collects actionable optimization tasks for the RoRD (Rotation-Robust Descriptors) project. Tasks are grouped by priority and module; pick them up as project schedule and resources allow.
+
+### I. Data Strategy & Augmentation
+
+> *Goal: improve the model's robustness and generalization while reducing dependence on large amounts of real data.*
+
+- [ ] **Introduce elastic transformations**
+    - **✔️ Value**: simulates the small physical deformations that can occur in chip manufacturing, making the model more robust to non-rigid changes.
+    - **📝 Plan**:
+        1. Add the `albumentations` library as a project dependency.
+        2. Integrate `A.ElasticTransform` into the augmentation pipeline of the `ICLayoutTrainingDataset` class in `train.py`.
+- [ ] **Build a synthetic layout generator**
+    - **✔️ Value**: real layouts are scarce and hard to obtain; programmatic generation yields large numbers of diverse training samples.
+    - **📝 Plan**:
+        1. Create a new script, e.g. `tools/generate_synthetic_layouts.py`.
+        2. Use the `gdstk` library to write functions that programmatically generate GDSII files containing standard cells of varying sizes, densities and types.
+        3. Reuse the logic of `tools/layout2png.py` to batch-convert the generated layouts to PNG images and enlarge the training set.
+
+### II. Model Architecture
+
+> *Goal: improve feature-extraction efficiency and accuracy while lowering compute cost.*
+
+- [ ] **Experiment with a more modern backbone**
+    - **✔️ Value**: VGG-16 is a classic but relatively inefficient. Newer architectures (e.g. ResNet, EfficientNet) reach better performance with fewer parameters and FLOPs.
+    - **📝 Plan**:
+        1. Modify the `__init__` method of the `RoRD` class in `models/rord.py`.
+        2. Replace `vgg16` with alternatives from `torchvision.models`; try `models.resnet34(pretrained=True)` or `models.efficientnet_b0(pretrained=True)`.
+        3. Adjust the input channels of the detection and descriptor heads accordingly.
+- [ ] **Integrate an attention mechanism**
+    - **✔️ Value**: guides the model to focus on the key geometric structures of a layout (corners, junctions) and ignore large blank or repetitive regions, improving feature quality.
+    - **📝 Plan**:
+        1. Pick a reliable attention module such as CBAM or SE-Net.
+        2. Insert it between `self.backbone` and the two `head`s in `models/rord.py`.
+
+### III. Training & Loss Function
+
+> *Goal: stabilize training and improve convergence.*
+
+- [ ] **Weight the losses automatically** (a sketch follows this list)
+    - **✔️ Value**: the detection and descriptor losses are currently summed with equal weights, which is hard to tune by hand. Automatic weighting lets the model balance the two tasks on its own.
+    - **📝 Plan**:
+        1. See the literature on uncertainty weighting for multi-task learning.
+        2. In `train.py`, define the loss weights as two learnable parameters `log_var_a` and `log_var_b`.
+        3. Change the total loss to `loss = torch.exp(-log_var_a) * det_loss + log_var_a + torch.exp(-log_var_b) * desc_loss + log_var_b`.
+        4. Add both new parameters to the optimizer.
+- [ ] **Sample hard examples from keypoint responses**
+    - **✔️ Value**: improves descriptor-learning efficiency. Sampling only in regions the model considers keypoints focuses learning on discriminative features.
+    - **📝 Plan**:
+        1. In the `compute_description_loss` function in `train.py`,
+        2. threshold or top-K the `det_original` output map to obtain keypoint coordinates, and
+        3. use those coordinates, instead of the grid generated by `torch.linspace`, as the sampling locations for the `anchor`, `positive` and `negative` descriptors.
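+
+A minimal sketch of the uncertainty weighting described in the first item above
+(the names `log_var_a` / `log_var_b` follow the plan; treat this as an
+illustration under those assumptions, not a final implementation):
+
+```python
+import torch
+
+# Learnable log-variances, one per task.
+log_var_a = torch.nn.Parameter(torch.zeros(()))  # detection
+log_var_b = torch.nn.Parameter(torch.zeros(()))  # description
+
+def total_loss(det_loss: torch.Tensor, desc_loss: torch.Tensor) -> torch.Tensor:
+    # Each task is scaled by exp(-log_var); the +log_var terms keep the
+    # learned weights from collapsing everything to zero loss.
+    return (torch.exp(-log_var_a) * det_loss + log_var_a
+            + torch.exp(-log_var_b) * desc_loss + log_var_b)
+
+# The new parameters must be trained jointly with the model, e.g.:
+# optimizer = torch.optim.Adam([*model.parameters(), log_var_a, log_var_b], lr=lr)
+```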
+
+### IV. Inference & Matching
+
+> *Goal: drastically speed up matching on large layouts and improve multi-scale detection.*
+
+- [ ] **Restructure the model as a Feature Pyramid Network (FPN)**
+    - **✔️ Value**: multi-scale matching currently rescales the image and re-runs inference several times, which is slow. An FPN produces features for all scales in a single forward pass, greatly accelerating matching.
+    - **📝 Plan**:
+        1. Modify `models/rord.py` to tap feature maps from several backbone stages (e.g. VGG's `relu2_2`, `relu3_3`, `relu4_3`).
+        2. Add upsampling and lateral connections to fuse them into a feature pyramid.
+        3. Change `match.py` to read features directly from the FPN levels, replacing the image-pyramid loop.
+- [ ] **Deduplicate keypoints after sliding-window matching**
+    - **✔️ Value**: the sliding window in `match.py` produces many duplicate keypoints in overlapping regions, inflating the cost of subsequent matching and potentially hurting accuracy.
+    - **📝 Plan**:
+        1. Before `extract_features_sliding_window` in `match.py` returns,
+        2. implement a non-maximum suppression (NMS) pass, and
+        3. filter `all_kps` and `all_descs` by keypoint position and detection score (the model must expose a response map) to drop redundant points.
+
+### V. Code & Project Structure
+
+> *Goal: improve maintainability, extensibility and ease of use.*
+
+- [ ] **Migrate configuration to YAML files**
+    - **✔️ Value**: `config.py` makes managing multiple experiment configurations awkward. YAML keeps each experiment's parameters separate, explicit and reproducible.
+    - **📝 Plan**:
+        1. Create a `configs` directory with a `base_config.yaml`.
+        2. Adopt the `OmegaConf` or `Hydra` library.
+        3. Update `train.py`, `match.py` and the other scripts to load configuration from YAML instead of importing `config.py`.
+- [ ] **Decouple the code modules**
+    - **✔️ Value**: `train.py` is too long and has too many responsibilities. Decoupling clarifies the structure and follows the single-responsibility principle.
+    - **📝 Plan**:
+        1. Move the `ICLayoutTrainingDataset` class from `train.py` to `data/ic_dataset.py`.
+        2. Create a new `losses.py` and move `compute_detection_loss` and `compute_description_loss` into it.
+
+### VI. Experiment Tracking & Evaluation
+
+> *Goal: establish a sound experimental workflow and more complete performance metrics.*
+
+- [ ] **Integrate experiment tracking (TensorBoard / W&B)**
+    - **✔️ Value**: plain log files make comparing runs hard; a dashboard can monitor and compare losses and metrics across runs in real time.
+    - **📝 Plan**:
+        1. Import `torch.utils.tensorboard.SummaryWriter` in `train.py`.
+        2. Log each loss with `writer.add_scalar()` inside the training loop.
+        3. Log evaluation metrics and the learning rate after validation.
+- [ ] **Add more comprehensive evaluation metrics** (see the sketch after this list)
+    - **✔️ Value**: the current metrics mainly measure bounding-box overlap. Adding mAP and geometric-error measurements gives a fuller picture of model performance.
+    - **📝 Plan**:
+        1. Implement mAP (mean Average Precision) in `evaluate.py`.
+        2. After an IoU match succeeds, decompose rotation/translation parameters from the homography `H` returned by `match_template_multiscale` and compare them with the ground-truth transform to compute the error.
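+
+A hedged sketch of the geometric-error idea from the last item (it assumes the
+recovered `H` is close to a similarity transform, which holds for the 8
+discrete rotations/mirrors targeted in training; the helper name is ours):
+
+```python
+import numpy as np
+
+def decompose_similarity(H: np.ndarray) -> tuple[float, float, tuple[float, float]]:
+    """Approximate rotation (deg), scale and translation of a 3x3 homography."""
+    a, b = float(H[0, 0]), float(H[1, 0])
+    scale = float(np.hypot(a, b))                    # isotropic scale estimate
+    theta_deg = float(np.degrees(np.arctan2(b, a)))  # in-plane rotation
+    tx, ty = float(H[0, 2]), float(H[1, 2])          # translation in pixels
+    return theta_deg, scale, (tx, ty)
+```
+
+Comparing these values against the ground-truth transform yields per-instance
+rotation/translation errors.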
+
+---
+
+## Training Requirements
+
+### 1. Dataset Type
+
+* **Format**: training data are PNG images of integrated-circuit (IC) layouts, either binarized black-and-white or grayscale.
+* **Source**: can be rasterized from GDSII (.gds) or OASIS (.oas) layout files.
+* **Content**: the dataset should cover layouts from different regions and design styles to ensure the model generalizes.
+* **Annotations**: **no manual annotation is needed for training**. The model is self-supervised: training pairs are generated automatically by rotating and mirroring the original images.
+
+### 2. Dataset Size
+
+* **Bring-up (pipeline validation)**: **100–200** high-resolution (e.g. 2048x2048) layout images — enough to verify the training loop runs and the loss converges.
+* **First usable model**: **1,000–2,000** layout images. At this scale the model learns reasonably robust geometric features and performs well on layouts similar to the training data.
+* **Production-grade model**: **5,000–10,000+** layout images. Good generalization across processes and design styles requires a large, diverse dataset.
+
+The training script `train.py` automatically splits the provided dataset 80/20 into training and validation sets.
+
+### 3. Compute
+
+* **Hardware**: **a CUDA-capable NVIDIA GPU is required**. Given the VGG-16 backbone and the relatively heavy geometry-aware loss, a mid-to-high-end GPU speeds training up considerably.
+* **Suggested models**:
+    * **Entry**: NVIDIA RTX 3060 / 4060
+    * **Mainstream**: NVIDIA RTX 3080 / 4070 / A4000
+    * **Professional**: NVIDIA RTX 3090 / 4090 / A6000
+* **CPU & RAM**: at least 8 CPU cores and 32 GB of RAM, so data preprocessing and loading do not become the bottleneck.
+
+### 4. VRAM
+
+The requirement can be estimated from the parameters in `config.py` and `train.py`:
+
+* **Architecture**: VGG-16 based.
+* **Batch size**: 8 by default.
+* **Patch size**: 256x256.
+
+Accounting for gradients and optimizer state on top of these, **at least 12 GB of VRAM is recommended**. With less, reduce `BATCH_SIZE` (e.g. to 4 or 2) at some cost in training speed and stability.
+
+### 5. Training Time Estimate
+
+Assuming one **NVIDIA RTX 3080 (10 GB)** GPU and a dataset of **2,000** layout images:
+
+* **Per epoch**: roughly 15–25 minutes.
+* **Total**: the configured number of epochs is 50:
+    * `50 epochs * 20 min/epoch ≈ 16.7 hours`
+* **Convergence**: early stopping (patience=10) halts training if the validation loss does not improve for 10 epochs, so the actual wall time is likely **10–20 hours**.
+
+### 6. Iterative Tuning Time
+
+Tuning is an iterative, time-consuming process. Based on the optimization points in `TRAINING_STRATEGY_ANALYSIS.md` and the suggestions above, the tuning phase may involve:
+
+* **Augmentation search (1–2 weeks)**: scale-jitter range, brightness/contrast parameters, alternative noise types, etc.
+* **Loss-weight balancing (1–2 weeks)**: `loss_function.md` lists several loss components (BCE, SmoothL1, Triplet, Manhattan, Sparsity, Binary); balancing their weights is critical to model quality.
+* **Hyperparameter search (2–4 weeks)**: grid search or Bayesian optimization over learning rate, batch size, optimizer (Adam, SGD, ...) and LR schedule.
+* **Architecture tweaks (optional, 2–4 weeks)**: alternative backbones (e.g. ResNet), changes to the depth or width of the detection and descriptor heads.
+
+**In total, expect roughly 1.5 to 3 months from data preparation to a stable, reliable, production-grade model.**
diff --git a/evaluate.py b/evaluate.py
index f987615..842c958 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -1,17 +1,17 @@
 # evaluate.py
+import argparse
+import json
+import os
+from pathlib import Path
+
 import torch
 from PIL import Image
-import json
-import os
-import argparse
-import config
-from models.rord import RoRD
-from utils.data_utils import get_transform
-from data.ic_dataset import ICLayoutDataset
-# (Updated) import the new matching function
 from match import match_template_multiscale
+from models.rord import RoRD
+from utils.config_loader import load_config, to_absolute_path
+from utils.data_utils import get_transform

 def compute_iou(box1, box2):
     x1, y1, w1, h1 = box1['x'], box1['y'], box1['width'], box1['height']
@@ -23,7 +23,7 @@
     return inter_area / union_area if union_area > 0 else 0

 # --- (Updated) evaluation function ---
-def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
+def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir, matching_cfg, iou_threshold):
     model.eval()
     all_tp, all_fp, all_fn = 0, 0, 0
@@ -59,7 +59,7 @@
             template_image = Image.open(template_path).convert('L')

             # (Updated) call the new multi-scale matcher
-            detected = match_template_multiscale(model, layout_image, template_image, transform)
+            detected = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)

             gt_boxes = gt_by_template.get(template_name, [])
@@ -76,7 +76,7 @@
                     if iou > best_iou:
                         best_iou, best_gt_idx = iou, i

-                if best_iou > config.IOU_THRESHOLD:
+                if best_iou > iou_threshold:
                     if not matched_gt[best_gt_idx]:
                         tp += 1
                         matched_gt[best_gt_idx] = True
@@ -96,17 +96,29 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Evaluate RoRD model performance")
-    parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
-    parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
-    parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
-    parser.add_argument('--templates_dir', type=str, default=config.TEMPLATE_DIR)
+    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="Path to the YAML config file")
+    parser.add_argument('--model_path', type=str, default=None, help="Model weights path; falls back to the config file")
+    parser.add_argument('--val_dir', type=str, default=None, help="Validation image directory; falls back to the config file")
+    parser.add_argument('--annotations_dir', type=str, default=None, help="Validation annotation directory; falls back to the config file")
+    parser.add_argument('--templates_dir', type=str, default=None, help="Template directory; falls back to the config file")
     args = parser.parse_args()

+    cfg = load_config(args.config)
+    config_dir = Path(args.config).resolve().parent
+    paths_cfg = cfg.paths
+    matching_cfg = cfg.matching
+    eval_cfg = cfg.evaluation
+
+    model_path = args.model_path or str(to_absolute_path(paths_cfg.model_path, config_dir))
+    val_dir = args.val_dir or str(to_absolute_path(paths_cfg.val_img_dir, config_dir))
+    annotations_dir = args.annotations_dir or str(to_absolute_path(paths_cfg.val_ann_dir, config_dir))
+    templates_dir = args.templates_dir or str(to_absolute_path(paths_cfg.template_dir, config_dir))
+    iou_threshold = float(eval_cfg.iou_threshold)
+
     model = RoRD().cuda()
-    model.load_state_dict(torch.load(args.model_path))
+    model.load_state_dict(torch.load(model_path))

-    # (Updated) no need to preload the dataset; pass the paths directly
-    results = evaluate(model, args.val_dir, args.annotations_dir, args.templates_dir)
+    results = evaluate(model, val_dir, annotations_dir, templates_dir, matching_cfg, iou_threshold)

     print("\n--- Evaluation Results ---")
     print(f"  Precision: {results['precision']:.4f}")
diff --git a/match.py b/match.py
index cc754c2..88dd7ad 100644
--- a/match.py
+++ b/match.py
@@ -1,15 +1,17 @@
 # match.py
-import torch
-import torch.nn.functional as F
-import numpy as np
-import cv2
-from PIL import Image
 import argparse
 import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image

-import config
 from models.rord import RoRD
+from utils.config_loader import load_config, to_absolute_path
 from utils.data_utils import get_transform

 # --- Feature extraction (mostly unchanged) ---
@@ -39,15 +41,16 @@
     return keypoints, descriptors

 # --- (New) sliding-window feature extraction ---
-def extract_features_sliding_window(model, large_image, transform):
+def extract_features_sliding_window(model, large_image, transform, matching_cfg):
     """Extract all keypoints and descriptors from a large image with a sliding window."""
     print("Extracting features from the large layout with a sliding window...")
     device = next(model.parameters()).device
     W, H = large_image.size
-    window_size = config.INFERENCE_WINDOW_SIZE
-    stride = config.INFERENCE_STRIDE
+    window_size = int(matching_cfg.inference_window_size)
+    stride = int(matching_cfg.inference_stride)
+    keypoint_threshold = float(matching_cfg.keypoint_threshold)

     all_kps = []
     all_descs = []
@@ -65,7 +68,7 @@
             patch_tensor = transform(patch).unsqueeze(0).to(device)

             # Extract features
-            kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, config.KEYPOINT_THRESHOLD)
+            kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, keypoint_threshold)

             if len(kps) > 0:
                 # Convert local window coordinates to global coordinates
@@ -94,26 +97,30 @@
     return matches

 # --- (Updated) multi-scale, multi-instance matching ---
-def match_template_multiscale(model, layout_image, template_image, transform):
+def match_template_multiscale(model, layout_image, template_image, transform, matching_cfg):
     """Search for the template across scales and detect multiple instances."""
     # 1. Extract all features from the layout with the sliding window
-    layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform)
+    layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform, matching_cfg)

-    if len(layout_kps) < config.MIN_INLIERS:
+    min_inliers = int(matching_cfg.min_inliers)
+    if len(layout_kps) < min_inliers:
         print("Too few keypoints extracted from the layout; matching aborted.")
         return []

     found_instances = []
     active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
+    pyramid_scales = [float(s) for s in matching_cfg.pyramid_scales]
+    keypoint_threshold = float(matching_cfg.keypoint_threshold)
+    ransac_threshold = float(matching_cfg.ransac_reproj_threshold)
     # 2. Iterate to detect multiple instances
     while True:
         current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)

         # Stop when too few active keypoints remain
-        if len(current_active_indices) < config.MIN_INLIERS:
+        if len(current_active_indices) < min_inliers:
             break

         current_layout_kps = layout_kps[current_active_indices]
@@ -123,7 +130,7 @@

         # 3. Image pyramid: try the template at each scale
         print("Searching for the template across pyramid scales...")
-        for scale in config.PYRAMID_SCALES:
+        for scale in pyramid_scales:
             W, H = template_image.size
             new_W, new_H = int(W * scale), int(H * scale)

@@ -132,7 +139,7 @@
             template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)

             # Extract features of the rescaled template
-            template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, config.KEYPOINT_THRESHOLD)
+            template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, keypoint_threshold)

             if len(template_kps) < 4:
                 continue
@@ -147,13 +154,13 @@
             dst_pts_indices = current_active_indices[matches[:, 1]]
             dst_pts = layout_kps[dst_pts_indices].cpu().numpy()

-            H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, config.RANSAC_REPROJ_THRESHOLD)
+            H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_threshold)

             if H is not None and mask.sum() > best_match_info['inliers']:
                 best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}

         # 4. If a best match was found across all scales, record it and mask its inliers
-        if best_match_info['inliers'] > config.MIN_INLIERS:
+        if best_match_info['inliers'] > min_inliers:
             print(f"Found a matching instance! Inliers: {best_match_info['inliers']}, template scale: {best_match_info['scale']:.2f}x")

             inlier_mask = best_match_info['mask'].ravel().astype(bool)
@@ -191,21 +198,27 @@

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Multi-scale template matching with RoRD")
-    parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
+    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="Path to the YAML config file")
+    parser.add_argument('--model_path', type=str, default=None, help="Model weights path; falls back to the config file")
    parser.add_argument('--layout', type=str, required=True)
    parser.add_argument('--template', type=str, required=True)
    parser.add_argument('--output', type=str)
    args = parser.parse_args()

+    cfg = load_config(args.config)
+    config_dir = Path(args.config).resolve().parent
+    matching_cfg = cfg.matching
+    model_path = args.model_path or str(to_absolute_path(cfg.paths.model_path, config_dir))
+
     transform = get_transform()
     model = RoRD().cuda()
-    model.load_state_dict(torch.load(args.model_path))
+    model.load_state_dict(torch.load(model_path))
     model.eval()

     layout_image = Image.open(args.layout).convert('L')
     template_image = Image.open(args.template).convert('L')

-    detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform)
+    detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform, matching_cfg)

     print("\nDetected bounding boxes:")
     for bbox in detected_bboxes:
diff --git a/pyproject.toml b/pyproject.toml
index fe3f458..d19550e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
     "pillow>=11.2.1",
     "torch>=2.7.1",
     "torchvision>=0.22.1",
+    "omegaconf>=2.3.0",
 ]

 [[tool.uv.index]]
diff --git a/train.py b/train.py
index 206865c..fb750fe 100644
--- a/train.py
+++ b/train.py
@@ -1,20 +1,18 @@
 # train.py
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from PIL import Image
-import numpy as np
-import cv2
-import os
 import argparse
 import logging
+import os
 from datetime import datetime
+from pathlib import Path

-# Project-internal modules
-import config
+import torch
+from torch.utils.data import DataLoader
+
+from data.ic_dataset import ICLayoutTrainingDataset
+from losses import compute_detection_loss, compute_description_loss
 from models.rord import RoRD
+from utils.config_loader import load_config, to_absolute_path
 from utils.data_utils import get_transform

 # Logging setup
@@ -34,207 +32,33 @@
     )
     return logging.getLogger(__name__)

-# --- (Updated) training-only dataset class ---
-class ICLayoutTrainingDataset(Dataset):
-    def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
-        self.image_dir = image_dir
-        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
-        self.patch_size = patch_size
-        self.transform = transform
-        self.scale_range = scale_range  # new scale-range parameter
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, index):
-        img_path = self.image_paths[index]
-        image = Image.open(img_path).convert('L')
-        W, H = image.size
-
-        # --- New: scale-jitter augmentation ---
-        # 1. Pick a random scale factor
-        scale = np.random.uniform(self.scale_range[0], self.scale_range[1])
-        # 2. Derive the crop size from the scale factor
-        crop_size = int(self.patch_size / scale)
-
-        # Keep the crop inside the image bounds
-        if crop_size > min(W, H):
-            crop_size = min(W, H)
-
-        # 3. Random crop
-        x = np.random.randint(0, W - crop_size + 1)
-        y = np.random.randint(0, H - crop_size + 1)
-        patch = image.crop((x, y, x + crop_size, y + crop_size))
-
-        # 4. Resize the crop back to the standard patch_size
-        patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
-        # --- End of scale jitter ---
-
-        # --- New: extra augmentations ---
-        # Brightness
-        if np.random.random() < 0.5:
-            brightness_factor = np.random.uniform(0.8, 1.2)
-            patch = patch.point(lambda x: int(x * brightness_factor))
-
-        # Contrast
-        if np.random.random() < 0.5:
-            contrast_factor = np.random.uniform(0.8, 1.2)
-            patch = patch.point(lambda x: int(((x - 128) * contrast_factor) + 128))
-
-        # Noise
-        if np.random.random() < 0.3:
-            patch_np = np.array(patch, dtype=np.float32)
-            noise = np.random.normal(0, 5, patch_np.shape)
-            patch_np = np.clip(patch_np + noise, 0, 255)
-            patch = Image.fromarray(patch_np.astype(np.uint8))
-        # --- End of extra augmentations ---
-
-        patch_np = np.array(patch)
-
-        # The 8 discrete geometric transforms (logic unchanged)
-        theta_deg = np.random.choice([0, 90, 180, 270])
-        is_mirrored = np.random.choice([True, False])
-        cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
-        M = cv2.getRotationMatrix2D((cx, cy), theta_deg, 1)
-
-        if is_mirrored:
-            T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
-            Flip = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
-            T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
-            M_mirror_3x3 = T2 @ Flip @ T1
-            M_3x3 = np.vstack([M, [0, 0, 1]])
-            H = (M_3x3 @ M_mirror_3x3).astype(np.float32)
-        else:
-            H = np.vstack([M, [0, 0, 1]]).astype(np.float32)
-
-        transformed_patch_np = cv2.warpPerspective(patch_np, H, (self.patch_size, self.patch_size))
-        transformed_patch = Image.fromarray(transformed_patch_np)
-
-        if self.transform:
-            patch = self.transform(patch)
-            transformed_patch = self.transform(transformed_patch)
-
-        H_tensor = torch.from_numpy(H[:2, :]).float()
-        return patch, transformed_patch, H_tensor
-
-# --- Feature-map warping & loss functions (improved) ---
-def warp_feature_map(feature_map, H_inv):
-    B, C, H, W = feature_map.size()
-    grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
-    return F.grid_sample(feature_map, grid, align_corners=False)
-
-def compute_detection_loss(det_original, det_rotated, H):
-    """Improved detection loss: BCE instead of MSE"""
-    with torch.no_grad():
-        H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
-        warped_det_rotated = warp_feature_map(det_rotated, H_inv)
-
-    # BCE loss, better suited to a binary target
-    bce_loss = F.binary_cross_entropy(det_original, warped_det_rotated)
-
-    # Smooth L1 loss as an auxiliary term
-    smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det_rotated)
-
-    return bce_loss + 0.1 * smooth_l1_loss
-
-def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
-    """Geometry-aware descriptor loss for IC layouts: encodes Manhattan structure"""
-    B, C, H_feat, W_feat = desc_original.size()
-
-    # Manhattan-aware sampling: focus on edge and corner regions
-    num_samples = 200
-
-    # Build a Manhattan-aligned sampling grid (horizontal and vertical first)
-    h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
-    w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
-
-    # Densify sampling along the Manhattan directions
-    manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
-    manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
-    manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1).unsqueeze(0).repeat(B, 1, 1)
-
-    # Sample the anchor points
-    anchor = F.grid_sample(desc_original, manhattan_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-
-    # Compute the corresponding positive points
-    coords_hom = torch.cat([manhattan_coords, torch.ones(B, manhattan_coords.size(1), 1, device=manhattan_coords.device)], dim=2)
-    M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
-    coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
-    positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-
-    # Layout-specific negative sampling: account for repeated structures
-    with torch.no_grad():
-        # 1. Geometry-aware negatives: regions under other Manhattan transforms
-        neg_coords = []
-        for b in range(B):
-            # Coordinates after Manhattan transforms (90-degree rotations etc.)
-            angles = [0, 90, 180, 270]
-            for angle in angles:
-                if angle != 0:
-                    theta = torch.tensor([angle * np.pi / 180])
-                    rot_matrix = torch.tensor([
-                        [torch.cos(theta), -torch.sin(theta), 0],
-                        [torch.sin(theta), torch.cos(theta), 0]
-                    ])
-                    rotated_coords = manhattan_coords[b] @ rot_matrix[:2, :2].T
-                    neg_coords.append(rotated_coords)
-
-        neg_coords = torch.stack(neg_coords[:B*num_samples//2]).reshape(B, -1, 2)
-
-        # 2. Hard negatives in feature space
-        negative_candidates = F.grid_sample(desc_rotated, neg_coords, align_corners=False).squeeze(2).transpose(1, 2)
-
-        # 3. Hard-sample selection under a Manhattan-distance constraint
-        anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
-        negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
-
-        # Manhattan (L1) distance instead of Euclidean
-        manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
-        hard_indices = torch.topk(manhattan_dist, k=anchor.size(1)//2, largest=False)[1]
-        negative = torch.gather(negative_candidates, 1, hard_indices)
-
-    # Layout-specific geometric-consistency losses
-    # 1. Manhattan-direction consistency loss
-    manhattan_loss = 0
-    for i in range(anchor.size(1)):
-        # Horizontal/vertical geometric consistency
-        anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
-        positive_norm = F.normalize(positive[:, i], p=2, dim=1)
-
-        # Encourage invariance of descriptors to Manhattan transforms
-        cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
-        manhattan_loss += torch.mean(1 - cos_sim)
-
-    # 2. Sparsity regularization (layout features are sparse)
-    sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
-
-    # 3. Binarized feature distance (handles binarized inputs)
-    binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
-
-    # Combined loss
-    triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')  # L1 distance
-    geometric_triplet = triplet_loss(anchor, positive, negative)
-
-    return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
-
 # --- (Updated) main & CLI ---
 def main(args):
-    # Logging setup
-    logger = setup_logging(args.save_dir)
-
+    cfg = load_config(args.config)
+    config_dir = Path(args.config).resolve().parent
+
+    data_dir = args.data_dir or str(to_absolute_path(cfg.paths.layout_dir, config_dir))
+    save_dir = args.save_dir or str(to_absolute_path(cfg.paths.save_dir, config_dir))
+    epochs = args.epochs if args.epochs is not None else int(cfg.training.num_epochs)
+    batch_size = args.batch_size if args.batch_size is not None else int(cfg.training.batch_size)
+    lr = args.lr if args.lr is not None else float(cfg.training.learning_rate)
+    patch_size = int(cfg.training.patch_size)
+    scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)
+
+    logger = setup_logging(save_dir)
+
     logger.info("--- Starting RoRD training ---")
-    logger.info(f"Training params: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
-    logger.info(f"Data dir: {args.data_dir}")
-    logger.info(f"Save dir: {args.save_dir}")
-
+    logger.info(f"Training params: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
+    logger.info(f"Data dir: {data_dir}")
+    logger.info(f"Save dir: {save_dir}")
+
     transform = get_transform()
-
-    # Pass the scale-jitter range when constructing the dataset
+
     dataset = ICLayoutTrainingDataset(
-        args.data_dir,
-        patch_size=config.PATCH_SIZE,
-        transform=transform,
-        scale_range=config.SCALE_JITTER_RANGE
+        data_dir,
+        patch_size=patch_size,
+        transform=transform,
+        scale_range=scale_range,
     )

     logger.info(f"Dataset size: {len(dataset)}")
@@ -246,13 +70,13 @@

     logger.info(f"Train size: {len(train_dataset)}, val size: {len(val_dataset)}")

-    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
-    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
+    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

     model = RoRD().cuda()
     logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

     # Learning-rate scheduler
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
@@ -264,7 +88,7 @@
     patience_counter = 0
     patience = 10

-    for epoch in range(args.epochs):
+    for epoch in range(epochs):
         # Training phase
         model.train()
         total_train_loss = 0
@@ -339,18 +163,19 @@
             patience_counter = 0

             # Save the best model
-            if not os.path.exists(args.save_dir):
-                os.makedirs(args.save_dir)
-            save_path = os.path.join(args.save_dir, 'rord_model_best.pth')
+            if not os.path.exists(save_dir):
+                os.makedirs(save_dir)
+            save_path = os.path.join(save_dir, 'rord_model_best.pth')
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': model.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'best_val_loss': best_val_loss,
                 'config': {
-                    'learning_rate': args.lr,
-                    'batch_size': args.batch_size,
-                    'epochs': args.epochs
+                    'learning_rate': lr,
+                    'batch_size': batch_size,
+                    'epochs': epochs,
+                    'config_path': str(Path(args.config).resolve()),
                 }
             }, save_path)
             logger.info(f"Best model saved to: {save_path}")
@@ -361,16 +186,17 @@
             break

     # Save the final model
-    save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
+    save_path = os.path.join(save_dir, 'rord_model_final.pth')
     torch.save({
-        'epoch': args.epochs,
+        'epoch': epochs,
         'model_state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'final_val_loss': avg_val_loss,
         'config': {
-            'learning_rate': args.lr,
-            'batch_size': args.batch_size,
-            'epochs': args.epochs
+            'learning_rate': lr,
+            'batch_size': batch_size,
+            'epochs': epochs,
+            'config_path': str(Path(args.config).resolve()),
         }
     }, save_path)
     logger.info(f"Final model saved to: {save_path}")
@@ -378,9 +204,10 @@

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train the RoRD model")
-    parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
-    parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
-    parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)
-    parser.add_argument('--batch_size', type=int, default=config.BATCH_SIZE)
-    parser.add_argument('--lr', type=float, default=config.LEARNING_RATE)
+    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="Path to the YAML config file")
+    parser.add_argument('--data_dir', type=str, default=None, help="Training data directory; falls back to the config file")
+    parser.add_argument('--save_dir', type=str, default=None, help="Model save directory; falls back to the config file")
+    parser.add_argument('--epochs', type=int, default=None, help="Number of epochs; falls back to the config file")
+    parser.add_argument('--batch_size', type=int, default=None, help="Batch size; falls back to the config file")
+    parser.add_argument('--lr', type=float, default=None, help="Learning rate; falls back to the config file")
     main(parser.parse_args())
\ No newline at end of file
diff --git a/uv.lock b/uv.lock
index bfa7e86..93f9aa6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -7,6 +7,12 @@ resolution-markers = [
     "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')",
 ]

+[[package]]
+name = "antlr4-python3-runtime"
+version = "4.9.3"
+source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
+sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" }
+
 [[package]]
 name = "cairocffi"
 version = "1.7.1"
@@ -408,6 +414,19 @@
 wheels = [
     { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265, upload-time = "2024-10-01T17:00:38.172Z" },
 ]

+[[package]]
+name = "omegaconf"
+version = "2.3.0"
+source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
+dependencies = [
+    { name = "antlr4-python3-runtime" },
+    { name = "pyyaml" },
+]
+sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" }
+wheels = [
+    { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size
= 79500, upload-time = "2022-12-08T20:59:19.686Z" }, +] + [[package]] name = "opencv-python" version = "4.11.0.86" @@ -475,6 +494,32 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552, upload-time = "2024-03-30T13:22:20.476Z" }, ] +[[package]] +name = "pyyaml" +version = "6.0.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631, upload-time = "2024-08-06T20:33:50.674Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873, upload-time = "2024-08-06T20:32:25.131Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302, upload-time = "2024-08-06T20:32:26.511Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154, upload-time = "2024-08-06T20:32:28.363Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223, upload-time = "2024-08-06T20:32:30.058Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542, upload-time = "2024-08-06T20:32:31.881Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164, upload-time = "2024-08-06T20:32:37.083Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611, upload-time = "2024-08-06T20:32:38.898Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591, upload-time = "2024-08-06T20:32:40.241Z" }, + { url = 
"https://pypi.tuna.tsinghua.edu.cn/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338, upload-time = "2024-08-06T20:32:41.93Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ef/e3/3af305b830494fa85d95f6d95ef7fa73f2ee1cc8ef5b495c7c3269fb835f/PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", size = 181309, upload-time = "2024-08-06T20:32:43.4Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/9f/3b1c20a0b7a3200524eb0076cc027a970d320bd3a6592873c85c92a08731/PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", size = 171679, upload-time = "2024-08-06T20:32:44.801Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428, upload-time = "2024-08-06T20:32:46.432Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361, upload-time = "2024-08-06T20:32:51.188Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523, upload-time = "2024-08-06T20:32:53.019Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660, upload-time = "2024-08-06T20:32:54.708Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597, upload-time = "2024-08-06T20:32:56.985Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/14/0d/e2c3b43bbce3cf6bd97c840b46088a3031085179e596d4929729d8d68270/PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", size = 140527, upload-time = "2024-08-06T20:33:03.001Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" }, +] + [[package]] name = "rord-layout-recognation" version = "0.1.0" @@ -485,6 +530,7 @@ dependencies = [ { name = "gdstk" }, { name = "klayout" }, { name = "numpy" }, + { name = "omegaconf" }, { name = "opencv-python" }, { name = "pillow" }, { name = "torch" }, @@ -498,6 +544,7 @@ requires-dist = [ { name = "gdstk", specifier = ">=0.9.60" }, { name = "klayout", specifier = ">=0.30.2" }, { name = 
"numpy", specifier = ">=2.3.0" }, + { name = "omegaconf", specifier = ">=2.3.0" }, { name = "opencv-python", specifier = ">=4.11.0.86" }, { name = "pillow", specifier = ">=11.2.1" }, { name = "torch", specifier = ">=2.7.1" }, From 09f513686dee8f11ee8aa0b150cecbadbfdb86a2 Mon Sep 17 00:00:00 2001 From: Jiao77 Date: Thu, 25 Sep 2025 20:29:56 +0800 Subject: [PATCH 2/3] add some change --- tools/klayoutconvertor.py | 159 -------------------------------------- 1 file changed, 159 deletions(-) delete mode 100644 tools/klayoutconvertor.py diff --git a/tools/klayoutconvertor.py b/tools/klayoutconvertor.py deleted file mode 100644 index 6ca5747..0000000 --- a/tools/klayoutconvertor.py +++ /dev/null @@ -1,159 +0,0 @@ -# tools/klayoutconvertor.py -#!/usr/bin/env python3 -""" -KLayout GDS to PNG Converter - -This script uses KLayout's Python API to convert GDS files to PNG images. -It accepts command-line arguments for input parameters. - -Requirements: - pip install klayout - -Usage: - python klayoutconvertor.py input.gds output.png [options] -""" - -import klayout.db as pya -import klayout.lay as lay -from PIL import Image -import os -import argparse -import sys - -Image.MAX_IMAGE_PIXELS = None - - -def export_gds_as_image( - gds_path: str, - output_path: str, - layers: list = [1, 2], - center_um: tuple = (0, 0), - view_size_um: float = 100.0, - resolution: int = 2048, - binarize: bool = True -) -> None: - """ - Export GDS file as PNG image using KLayout. - - Args: - gds_path: Input GDS file path - output_path: Output PNG file path - layers: List of layer numbers to include - center_um: Center coordinates in micrometers (x, y) - view_size_um: View size in micrometers - resolution: Output image resolution - binarize: Whether to convert to black and white - """ - if not os.path.exists(gds_path): - raise FileNotFoundError(f"Input file not found: {gds_path}") - - # Ensure output directory exists - output_dir = os.path.dirname(output_path) - if output_dir: - os.makedirs(output_dir, exist_ok=True) - - layout = pya.Layout() - layout.read(gds_path) - top = layout.top_cell() - - # Create layout view - view = lay.LayoutView() - view.set_config("background-color", "#ffffff") - view.set_config("grid-visible", "false") - - # Load layout into view correctly - view.load_layout(gds_path) - - # Add all layers - view.add_missing_layers() - - # Configure view to show entire layout with reasonable resolution - if view_size_um > 0: - # Use specified view size - box = pya.DBox( - center_um[0] - view_size_um / 2, - center_um[1] - view_size_um / 2, - center_um[0] + view_size_um / 2, - center_um[1] + view_size_um / 2 - ) - else: - # Use full layout bounds with size limit - bbox = top.bbox() - if bbox: - # Convert to micrometers (KLayout uses database units) - dbu = layout.dbu - box = pya.DBox( - bbox.left * dbu, - bbox.bottom * dbu, - bbox.right * dbu, - bbox.top * dbu - ) - - else: - # Fallback to 100x100 um if empty layout - box = pya.DBox(-50, -50, 50, 50) - - view.max_hier() - view.zoom_box(box) - - # Save to temporary file first, then load with PIL - import tempfile - temp_path = tempfile.NamedTemporaryFile(suffix='.png', delete=False).name - - try: - view.save_image(temp_path, resolution, resolution) - img = Image.open(temp_path) - - if binarize: - # Convert to grayscale and binarize - img = img.convert("L") - img = img.point(lambda x: 255 if x > 128 else 0, '1') - else: - # Convert to grayscale - img = img.convert("L") - - img.save(output_path) - finally: - # Clean up temp file - if os.path.exists(temp_path): - 
-            os.unlink(temp_path)
-
-
-def main():
-    """Main CLI entry point."""
-    parser = argparse.ArgumentParser(description='Convert GDS to PNG using KLayout')
-    parser.add_argument('input', help='Input GDS file')
-    parser.add_argument('output', help='Output PNG file')
-    parser.add_argument('--layers', nargs='+', type=int, default=[1, 2],
-                        help='Layers to include (default: 1 2)')
-    parser.add_argument('--center-x', type=float, default=0,
-                        help='Center X coordinate in micrometers (default: 0)')
-    parser.add_argument('--center-y', type=float, default=0,
-                        help='Center Y coordinate in micrometers (default: 0)')
-    parser.add_argument('--size', type=float, default=0,
-                        help='View size in micrometers (default: 0 = full layout)')
-    parser.add_argument('--resolution', type=int, default=2048,
-                        help='Output image resolution (default: 2048)')
-    parser.add_argument('--no-binarize', action='store_true',
-                        help='Disable binarization (keep grayscale)')
-
-    args = parser.parse_args()
-
-    try:
-        export_gds_as_image(
-            gds_path=args.input,
-            output_path=args.output,
-            layers=args.layers,
-            center_um=(args.center_x, args.center_y),
-            view_size_um=args.size,
-            resolution=args.resolution,
-            binarize=not args.no_binarize
-        )
-        print("Conversion completed successfully!")
-    except Exception as e:
-        print(f"Error: {e}")
-        sys.exit(1)
-
-
-if __name__ == '__main__':
-    main()
\ No newline at end of file

From 8c6c5592b67180588bdd72ba88697287c0d6226c Mon Sep 17 00:00:00 2001
From: Jiao77
Date: Thu, 25 Sep 2025 20:30:31 +0800
Subject: [PATCH 3/3] add some change twice

---
 losses.py              | 138 +++++++++++++++++++++++++++++++++++++++++
 utils/config_loader.py |  23 +++++++
 2 files changed, 161 insertions(+)
 create mode 100644 losses.py
 create mode 100644 utils/config_loader.py

diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000..940610a
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,138 @@
+"""Loss utilities for RoRD training."""
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def _augment_homography_matrix(h_2x3: torch.Tensor) -> torch.Tensor:
+    """Append the third row [0, 0, 1] to build a full 3x3 homography."""
+    if h_2x3.dim() != 3 or h_2x3.size(1) != 2 or h_2x3.size(2) != 3:
+        raise ValueError("Expected homography with shape (B, 2, 3)")
+
+    batch_size = h_2x3.size(0)
+    device = h_2x3.device
+    bottom_row = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=h_2x3.dtype)
+    bottom_row = bottom_row.view(1, 1, 3).expand(batch_size, -1, -1)
+    return torch.cat([h_2x3, bottom_row], dim=1)
+
+
+def _pixel_to_normalized_homography(h_full: torch.Tensor, patch_size: int) -> torch.Tensor:
+    """Convert a pixel-space homography into grid_sample's normalized coordinates.
+
+    The dataset builds homographies with OpenCV in pixel coordinates, while
+    affine_grid / grid_sample work in [-1, 1]. Assumes square patches and the
+    align_corners=False convention, where pixel x maps to u = (2x + 1) / S - 1.
+    """
+    s = float(patch_size)
+    t = torch.tensor(
+        [[2.0 / s, 0.0, 1.0 / s - 1.0],
+         [0.0, 2.0 / s, 1.0 / s - 1.0],
+         [0.0, 0.0, 1.0]],
+        device=h_full.device,
+        dtype=h_full.dtype,
+    )
+    return t @ h_full @ torch.inverse(t)
+
+
+def warp_feature_map(feature_map: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+    """Warp a feature map with a (B, 2, 3) affine given in normalized coordinates."""
+    return F.grid_sample(
+        feature_map,
+        F.affine_grid(theta, feature_map.size(), align_corners=False),
+        align_corners=False,
+    )
+
+
+def compute_detection_loss(
+    det_original: torch.Tensor,
+    det_rotated: torch.Tensor,
+    h: torch.Tensor,
+    patch_size: int = 256,
+) -> torch.Tensor:
+    """Binary cross-entropy + smooth L1 detection loss.
+
+    `patch_size` defaults to training.patch_size (256) so existing call sites
+    keep working; pass the actual patch side if it differs.
+    """
+    h_full = _augment_homography_matrix(h)
+    # cv2.warpPerspective maps a source point x to H @ x in the warped image,
+    # so sampling det_rotated at H @ p (in normalized coordinates) aligns it
+    # back to det_original's frame.
+    h_norm = _pixel_to_normalized_homography(h_full, patch_size)
+    warped_det = warp_feature_map(det_rotated, h_norm[:, :2, :])
+
+    bce_loss = F.binary_cross_entropy(det_original, warped_det)
+    smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det)
+    return bce_loss + 0.1 * smooth_l1_loss
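+
+# Usage sketch (illustrative only): with detection maps `det_o`, `det_r` of
+# shape (B, 1, Hf, Wf) from the two views and the dataset's (B, 2, 3)
+# pixel-space homography `h`,
+#
+#     loss = compute_detection_loss(det_o, det_r, h, patch_size=256)
+#
+# `patch_size` must match the patch side used when `h` was generated
+# (training.patch_size in configs/base_config.yaml).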
+
+
+def compute_description_loss(
+    desc_original: torch.Tensor,
+    desc_rotated: torch.Tensor,
+    h: torch.Tensor,
+    margin: float = 1.0,
+    patch_size: int = 256,
+) -> torch.Tensor:
+    """Triplet-style descriptor loss with Manhattan-aware sampling."""
+    batch_size, channels, height, width = desc_original.size()
+    num_samples = 200
+
+    grid_side = int(math.sqrt(num_samples))
+    h_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
+    w_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
+
+    manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
+    manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
+    manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1)
+    manhattan_coords = manhattan_coords.unsqueeze(0).repeat(batch_size, 1, 1)
+
+    anchor = F.grid_sample(
+        desc_original,
+        manhattan_coords.unsqueeze(1),
+        align_corners=False,
+    ).squeeze(2).transpose(1, 2)
+
+    coords_hom = torch.cat(
+        [manhattan_coords, torch.ones(batch_size, manhattan_coords.size(1), 1, device=desc_original.device)],
+        dim=2,
+    )
+
+    h_full = _augment_homography_matrix(h)
+    # Map the sample points forward into the warped patch (see
+    # compute_detection_loss); the transforms used here are affine, so the
+    # projective row stays trivial and the [:, :, :2] projection is exact.
+    h_norm = _pixel_to_normalized_homography(h_full, patch_size)
+    coords_transformed = (coords_hom @ h_norm.transpose(1, 2))[:, :, :2]
+
+    positive = F.grid_sample(
+        desc_rotated,
+        coords_transformed.unsqueeze(1),
+        align_corners=False,
+    ).squeeze(2).transpose(1, 2)
+
+    negative_list = []
+    if manhattan_coords.size(1) > 0:
+        angles = [0, 90, 180, 270]
+        for angle in angles:
+            if angle == 0:
+                continue
+            theta = torch.tensor(angle * math.pi / 180.0, device=desc_original.device)
+            cos_t = torch.cos(theta)
+            sin_t = torch.sin(theta)
+            rot = torch.stack(
+                [
+                    torch.stack([cos_t, -sin_t]),
+                    torch.stack([sin_t, cos_t]),
+                ]
+            )
+            rotated_coords = manhattan_coords @ rot.T
+            negative_list.append(rotated_coords)
+
+    if negative_list:
+        neg_coords = torch.stack(negative_list, dim=1).reshape(batch_size, -1, 2)
+        negative_candidates = F.grid_sample(
+            desc_rotated,
+            neg_coords.unsqueeze(1),
+            align_corners=False,
+        ).squeeze(2).transpose(1, 2)
+
+        anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
+        negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
+        manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
+
+        k = max(anchor.size(1) // 2, 1)
+        hard_indices = torch.topk(manhattan_dist, k=k, largest=False)[1]
+        idx_expand = hard_indices.unsqueeze(-1).expand(-1, -1, -1, negative_candidates.size(2))
+        negative = torch.gather(negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1), 2, idx_expand)
+        negative = negative.mean(dim=2)
+    else:
+        negative = torch.zeros_like(anchor)
+
+    triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')
+    geometric_triplet = triplet_loss(anchor, positive, negative)
+
+    manhattan_loss = 0.0
+    for i in range(anchor.size(1)):
+        anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
+        positive_norm = F.normalize(positive[:, i], p=2, dim=1)
+        cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
+        manhattan_loss += torch.mean(1 - cos_sim)
+
+    manhattan_loss = manhattan_loss / max(anchor.size(1), 1)
+    sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
+    binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
+
+    return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
diff --git a/utils/config_loader.py b/utils/config_loader.py
new file mode 100644
index 0000000..b9b1cf8
--- /dev/null
+++ b/utils/config_loader.py
@@ -0,0 +1,23 @@
+"""Configuration loading utilities using OmegaConf."""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Union
+
+from omegaconf import DictConfig,
OmegaConf + + +def load_config(config_path: Union[str, Path]) -> DictConfig: + """Load a YAML configuration file into a DictConfig.""" + path = Path(config_path) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {path}") + return OmegaConf.load(path) + + +def to_absolute_path(path_str: str, base_dir: Union[str, Path]) -> Path: + """Resolve a possibly relative path against the configuration file directory.""" + path = Path(path_str).expanduser() + if path.is_absolute(): + return path.resolve() + return (Path(base_dir) / path).resolve()