Commit: chenge to english version
config.py: 22 lines changed
@@ -1,26 +1,26 @@
 # config.py
 
-# --- 训练参数 ---
-LEARNING_RATE = 5e-5 # 降低学习率,提高训练稳定性
-BATCH_SIZE = 8 # 增加批次大小,提高训练效率
-NUM_EPOCHS = 50 # 增加训练轮数
+# --- Training Parameters ---
+LEARNING_RATE = 5e-5 # Reduce learning rate for improved training stability
+BATCH_SIZE = 8 # Increase batch size for improved training efficiency
+NUM_EPOCHS = 50 # Increase training epochs
 PATCH_SIZE = 256
-# (优化) 训练时尺度抖动范围 - 缩小范围提高稳定性
+# (Optimization) Scale jitter range during training - reduced range for improved stability
 SCALE_JITTER_RANGE = (0.8, 1.2)
 
-# --- 匹配与评估参数 ---
+# --- Matching and Evaluation Parameters ---
 KEYPOINT_THRESHOLD = 0.5
 RANSAC_REPROJ_THRESHOLD = 5.0
 MIN_INLIERS = 15
 IOU_THRESHOLD = 0.5
-# (新增) 推理时模板匹配的图像金字塔尺度
+# (New) Image pyramid scales for template matching during inference
 PYRAMID_SCALES = [0.75, 1.0, 1.5]
-# (新增) 推理时处理大版图的滑动窗口参数
+# (New) Sliding window parameters for processing large layouts during inference
 INFERENCE_WINDOW_SIZE = 1024
-INFERENCE_STRIDE = 768 # 小于INFERENCE_WINDOW_SIZE以保证重叠
+INFERENCE_STRIDE = 768 # Less than INFERENCE_WINDOW_SIZE to ensure overlap
 
-# --- 文件路径 ---
-# (路径保持不变, 请根据您的环境修改)
+# --- File Paths ---
+# (Paths remain unchanged, please modify according to your environment)
 LAYOUT_DIR = 'path/to/layouts'
 SAVE_DIR = 'path/to/save'
 VAL_IMG_DIR = 'path/to/val/images'
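A quick illustration of how the inference constants interact (a standalone sketch, not part of the commit; the 3000x2000 layout size is assumed for illustration). Because INFERENCE_STRIDE is smaller than INFERENCE_WINDOW_SIZE, adjacent windows overlap by 256 pixels, so keypoints near a window border are still seen by a neighboring window:

# Sketch: how the windowing constants tile a layout (standard library only).
INFERENCE_WINDOW_SIZE = 1024
INFERENCE_STRIDE = 768
PYRAMID_SCALES = [0.75, 1.0, 1.5]

W, H = 3000, 2000  # hypothetical layout size, for illustration only

# A window starts every STRIDE pixels; neighbors share
# WINDOW - STRIDE = 256 pixels of overlap along each axis.
windows = [
    (x, y, min(x + INFERENCE_WINDOW_SIZE, W), min(y + INFERENCE_WINDOW_SIZE, H))
    for y in range(0, H, INFERENCE_STRIDE)
    for x in range(0, W, INFERENCE_STRIDE)
]
print(len(windows), "windows; overlap =", INFERENCE_WINDOW_SIZE - INFERENCE_STRIDE)

# The template is searched at each pyramid scale, e.g. a 200x200 template
# is resized to 150x150, 200x200, and 300x300 before feature extraction.
print([(int(200 * s), int(200 * s)) for s in PYRAMID_SCALES])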
data/ic_dataset.py
@@ -6,12 +6,12 @@ import json
 class ICLayoutDataset(Dataset):
     def __init__(self, image_dir, annotation_dir=None, transform=None):
         """
-        初始化 IC 版图数据集。
+        Initialize the IC layout dataset.
 
-        参数:
-            image_dir (str): 存储 PNG 格式 IC 版图图像的目录路径。
-            annotation_dir (str, optional): 存储 JSON 格式注释文件的目录路径。
-            transform (callable, optional): 应用于图像的可选变换(如 Sobel 边缘检测)。
+        Args:
+            image_dir (str): Directory path containing PNG format IC layout images.
+            annotation_dir (str, optional): Directory path containing JSON format annotation files.
+            transform (callable, optional): Optional transform to apply to images (e.g., Sobel edge detection).
         """
         self.image_dir = image_dir
         self.annotation_dir = annotation_dir
@@ -24,25 +24,25 @@ class ICLayoutDataset(Dataset):
 
     def __len__(self):
         """
-        返回数据集中的图像数量。
+        Return the number of images in the dataset.
 
-        返回:
-            int: 数据集大小。
+        Returns:
+            int: Dataset size.
         """
         return len(self.images)
 
     def __getitem__(self, idx):
         """
-        获取指定索引的图像和注释。
+        Get image and annotation at specified index.
 
-        参数:
-            idx (int): 图像索引。
+        Args:
+            idx (int): Image index.
 
-        返回:
-            tuple: (image, annotation),image 为处理后的图像,annotation 为注释字典或空字典。
+        Returns:
+            tuple: (image, annotation), where image is the processed image and annotation is the annotation dict or empty dict.
         """
         img_path = os.path.join(self.image_dir, self.images[idx])
-        image = Image.open(img_path).convert('L') # 转换为灰度图
+        image = Image.open(img_path).convert('L') # Convert to grayscale
         if self.transform:
             image = self.transform(image)
 
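For orientation, a minimal usage sketch of this dataset class (the directory paths are the placeholders from config.py, not real paths; the annotation directory is optional):

# Usage sketch (assumed paths; not part of the commit)
from data.ic_dataset import ICLayoutDataset
from utils.data_utils import get_transform

dataset = ICLayoutDataset(
    image_dir='path/to/layouts',           # PNG layout images
    annotation_dir='path/to/annotations',  # optional JSON annotations
    transform=get_transform(),             # Sobel -> ToTensor -> 3-channel -> Normalize
)
image, annotation = dataset[0]  # transformed image tensor, annotation dict (or {})
print(len(dataset), image.shape)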
evaluate.py: 36 lines changed
@@ -10,7 +10,7 @@ import config
 from models.rord import RoRD
 from utils.data_utils import get_transform
 from data.ic_dataset import ICLayoutDataset
-# (已修改) 导入新的匹配函数
+# (Modified) Import new matching function
 from match import match_template_multiscale
 
 def compute_iou(box1, box2):
@@ -22,48 +22,48 @@ def compute_iou(box1, box2):
     union_area = w1 * h1 + w2 * h2 - inter_area
     return inter_area / union_area if union_area > 0 else 0
 
-# --- (已修改) 评估函数 ---
+# --- (Modified) Evaluation function ---
 def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
     model.eval()
     all_tp, all_fp, all_fn = 0, 0, 0
 
-    # 只需要一个统一的 transform 给匹配函数内部使用
+    # Only need a unified transform for internal use by matching function
     transform = get_transform()
 
     template_paths = [os.path.join(template_dir, f) for f in os.listdir(template_dir) if f.endswith('.png')]
     layout_image_names = [f for f in os.listdir(val_dataset_dir) if f.endswith('.png')]
 
-    # (已修改) 循环遍历验证集中的每个版图文件
+    # (Modified) Loop through each layout file in validation set
     for layout_name in layout_image_names:
-        print(f"\n正在评估版图: {layout_name}")
+        print(f"\nEvaluating layout: {layout_name}")
         layout_path = os.path.join(val_dataset_dir, layout_name)
         annotation_path = os.path.join(val_annotations_dir, layout_name.replace('.png', '.json'))
 
-        # 加载原始PIL图像,以支持滑动窗口
+        # Load original PIL image to support sliding window
         layout_image = Image.open(layout_path).convert('L')
 
-        # 加载标注信息
+        # Load annotation information
         if not os.path.exists(annotation_path):
            continue
         with open(annotation_path, 'r') as f:
             annotation = json.load(f)
 
-        # 按模板对真实标注进行分组
+        # Group ground truth annotations by template
         gt_by_template = {os.path.basename(box['template']): [] for box in annotation.get('boxes', [])}
         for box in annotation.get('boxes', []):
             gt_by_template[os.path.basename(box['template'])].append(box)
 
-        # 遍历每个模板,在当前版图上进行匹配
+        # Iterate through each template and perform matching on current layout
         for template_path in template_paths:
             template_name = os.path.basename(template_path)
             template_image = Image.open(template_path).convert('L')
 
-            # (已修改) 调用新的多尺度匹配函数
+            # (Modified) Call new multi-scale matching function
             detected = match_template_multiscale(model, layout_image, template_image, transform)
 
             gt_boxes = gt_by_template.get(template_name, [])
 
-            # 计算 TP, FP, FN (这部分逻辑不变)
+            # Calculate TP, FP, FN (this logic remains unchanged)
             matched_gt = [False] * len(gt_boxes)
             tp = 0
             if len(detected) > 0:
@@ -88,14 +88,14 @@ def evaluate(model, val_dataset_dir, val_annotations_dir, template_dir):
             all_fp += fp
             all_fn += fn
 
-    # 计算最终指标
+    # Calculate final metrics
     precision = all_tp / (all_tp + all_fp) if (all_tp + all_fp) > 0 else 0
     recall = all_tp / (all_tp + all_fn) if (all_tp + all_fn) > 0 else 0
     f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
     return {'precision': precision, 'recall': recall, 'f1': f1}
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="评估 RoRD 模型性能")
+    parser = argparse.ArgumentParser(description="Evaluate RoRD model performance")
     parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
     parser.add_argument('--val_dir', type=str, default=config.VAL_IMG_DIR)
     parser.add_argument('--annotations_dir', type=str, default=config.VAL_ANN_DIR)
@@ -105,10 +105,10 @@ if __name__ == "__main__":
     model = RoRD().cuda()
     model.load_state_dict(torch.load(args.model_path))
 
-    # (已修改) 不再需要预加载数据集,直接传入路径
+    # (Modified) No longer need to preload dataset, directly pass paths
     results = evaluate(model, args.val_dir, args.annotations_dir, args.templates_dir)
 
-    print("\n--- 评估结果 ---")
-    print(f" 精确率 (Precision): {results['precision']:.4f}")
-    print(f" 召回率 (Recall): {results['recall']:.4f}")
-    print(f" F1 分数 (F1 Score): {results['f1']:.4f}")
+    print("\n--- Evaluation Results ---")
+    print(f" Precision: {results['precision']:.4f}")
+    print(f" Recall: {results['recall']:.4f}")
+    print(f" F1 Score: {results['f1']:.4f}")
match.py: 72 lines changed
@@ -12,7 +12,7 @@ import config
 from models.rord import RoRD
 from utils.data_utils import get_transform
 
-# --- 特征提取函数 (基本无变动) ---
+# --- Feature extraction functions (unchanged) ---
 def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
     with torch.no_grad():
         detection_map, desc = model(image_tensor)
@@ -24,26 +24,26 @@ def extract_keypoints_and_descriptors(model, image_tensor, kp_thresh):
         if len(coords) == 0:
             return torch.tensor([], device=device), torch.tensor([], device=device)
 
-        # 描述子采样
+        # Descriptor sampling
         coords_for_grid = coords.flip(1).view(1, -1, 1, 2) # N, 2 -> 1, N, 1, 2 (x,y)
-        # 归一化到 [-1, 1]
+        # Normalize to [-1, 1]
         coords_for_grid = coords_for_grid / torch.tensor([(desc.shape[3]-1)/2, (desc.shape[2]-1)/2], device=device) - 1
 
         descriptors = F.grid_sample(desc, coords_for_grid, align_corners=True).squeeze().T
         descriptors = F.normalize(descriptors, p=2, dim=1)
 
-        # 将关键点坐标从特征图尺度转换回图像尺度
-        # VGG到relu4_3的下采样率为8
+        # Convert keypoint coordinates from feature map scale back to image scale
+        # VGG downsampling rate to relu4_3 is 8
         keypoints = coords.flip(1) * 8.0 # x, y
 
         return keypoints, descriptors
 
-# --- (新增) 滑动窗口特征提取函数 ---
+# --- (New) Sliding window feature extraction function ---
 def extract_features_sliding_window(model, large_image, transform):
     """
-    使用滑动窗口从大图上提取所有关键点和描述子
+    Extract all keypoints and descriptors from large image using sliding window
     """
-    print("使用滑动窗口提取大版图特征...")
+    print("Using sliding window to extract features from large layout...")
     device = next(model.parameters()).device
     W, H = large_image.size
     window_size = config.INFERENCE_WINDOW_SIZE
@@ -54,21 +54,21 @@ def extract_features_sliding_window(model, large_image, transform):
 
     for y in range(0, H, stride):
         for x in range(0, W, stride):
-            # 确保窗口不越界
+            # Ensure window does not exceed boundaries
             x_end = min(x + window_size, W)
             y_end = min(y + window_size, H)
 
-            # 裁剪窗口
+            # Crop window
             patch = large_image.crop((x, y, x_end, y_end))
 
-            # 预处理
+            # Preprocess
             patch_tensor = transform(patch).unsqueeze(0).to(device)
 
-            # 提取特征
+            # Extract features
             kps, descs = extract_keypoints_and_descriptors(model, patch_tensor, config.KEYPOINT_THRESHOLD)
 
             if len(kps) > 0:
-                # 将局部坐标转换为全局坐标
+                # Convert local coordinates to global coordinates
                 kps[:, 0] += x
                 kps[:, 1] += y
                 all_kps.append(kps)
@@ -77,11 +77,11 @@ def extract_features_sliding_window(model, large_image, transform):
     if not all_kps:
         return torch.tensor([], device=device), torch.tensor([], device=device)
 
-    print(f"大版图特征提取完毕,共找到 {sum(len(k) for k in all_kps)} 个关键点。")
+    print(f"Large layout feature extraction completed, found {sum(len(k) for k in all_kps)} keypoints in total.")
     return torch.cat(all_kps, dim=0), torch.cat(all_descs, dim=0)
 
 
-# --- 互近邻匹配 (无变动) ---
+# --- Mutual nearest neighbor matching (unchanged) ---
 def mutual_nearest_neighbor(descs1, descs2):
     if len(descs1) == 0 or len(descs2) == 0:
         return torch.empty((0, 2), dtype=torch.int64)
@@ -93,26 +93,26 @@ def mutual_nearest_neighbor(descs1, descs2):
     matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
     return matches
 
-# --- (已修改) 多尺度、多实例匹配主函数 ---
+# --- (Modified) Multi-scale, multi-instance matching main function ---
 def match_template_multiscale(model, layout_image, template_image, transform):
     """
-    在不同尺度下搜索模板,并检测多个实例
+    Search for template at different scales and detect multiple instances
     """
-    # 1. 对大版图使用滑动窗口提取全部特征
+    # 1. Use sliding window to extract all features from large layout
     layout_kps, layout_descs = extract_features_sliding_window(model, layout_image, transform)
 
     if len(layout_kps) < config.MIN_INLIERS:
-        print("从大版图中提取的关键点过少,无法进行匹配。")
+        print("Too few keypoints extracted from large layout, cannot perform matching.")
        return []
 
     found_instances = []
     active_layout_mask = torch.ones(len(layout_kps), dtype=bool, device=layout_kps.device)
 
-    # 2. 多实例迭代检测
+    # 2. Multi-instance iterative detection
     while True:
         current_active_indices = torch.nonzero(active_layout_mask).squeeze(1)
 
-        # 如果剩余活动关键点过少,则停止
+        # Stop if remaining active keypoints are too few
         if len(current_active_indices) < config.MIN_INLIERS:
             break
 
@@ -121,28 +121,28 @@ def match_template_multiscale(model, layout_image, template_image, transform):
 
         best_match_info = {'inliers': 0, 'H': None, 'src_pts': None, 'dst_pts': None, 'mask': None}
 
-        # 3. 图像金字塔:遍历模板的每个尺度
-        print("在新尺度下搜索模板...")
+        # 3. Image pyramid: iterate through each scale of template
+        print("Searching for template at new scale...")
         for scale in config.PYRAMID_SCALES:
             W, H = template_image.size
             new_W, new_H = int(W * scale), int(H * scale)
 
-            # 缩放模板
+            # Scale template
             scaled_template = template_image.resize((new_W, new_H), Image.LANCZOS)
             template_tensor = transform(scaled_template).unsqueeze(0).to(layout_kps.device)
 
-            # 提取缩放后模板的特征
+            # Extract features from scaled template
             template_kps, template_descs = extract_keypoints_and_descriptors(model, template_tensor, config.KEYPOINT_THRESHOLD)
 
             if len(template_kps) < 4: continue
 
-            # 匹配当前尺度的模板和活动状态的版图特征
+            # Match current scale template with active layout features
             matches = mutual_nearest_neighbor(template_descs, current_layout_descs)
 
             if len(matches) < 4: continue
 
             # RANSAC
-            # 注意:模板关键点坐标需要还原到原始尺寸,才能计算正确的H
+            # Note: template keypoint coordinates need to be restored to original size to calculate correct H
             src_pts = template_kps[matches[:, 0]].cpu().numpy() / scale
             dst_pts_indices = current_active_indices[matches[:, 1]]
             dst_pts = layout_kps[dst_pts_indices].cpu().numpy()
@@ -152,9 +152,9 @@ def match_template_multiscale(model, layout_image, template_image, transform):
             if H is not None and mask.sum() > best_match_info['inliers']:
                 best_match_info = {'inliers': mask.sum(), 'H': H, 'mask': mask, 'scale': scale, 'dst_pts': dst_pts}
 
-        # 4. 如果在所有尺度中找到了最佳匹配,则记录并屏蔽
+        # 4. If best match found across all scales, record and mask
         if best_match_info['inliers'] > config.MIN_INLIERS:
-            print(f"找到一个匹配实例!内点数: {best_match_info['inliers']}, 使用的模板尺度: {best_match_info['scale']:.2f}x")
+            print(f"Found a matching instance! Inliers: {best_match_info['inliers']}, Template scale used: {best_match_info['scale']:.2f}x")
 
             inlier_mask = best_match_info['mask'].ravel().astype(bool)
             inlier_layout_kps = best_match_info['dst_pts'][inlier_mask]
@@ -165,15 +165,15 @@ def match_template_multiscale(model, layout_image, template_image, transform):
             instance = {'x': int(x_min), 'y': int(y_min), 'width': int(x_max - x_min), 'height': int(y_max - y_min), 'homography': best_match_info['H']}
             found_instances.append(instance)
 
-            # 屏蔽已匹配区域的关键点,以便检测下一个实例
+            # Mask keypoints in matched region to detect next instance
             kp_x, kp_y = layout_kps[:, 0], layout_kps[:, 1]
             region_mask = (kp_x >= x_min) & (kp_x <= x_max) & (kp_y >= y_min) & (kp_y <= y_max)
             active_layout_mask[region_mask] = False
 
-            print(f"剩余活动关键点: {active_layout_mask.sum()}")
+            print(f"Remaining active keypoints: {active_layout_mask.sum()}")
         else:
-            # 如果在所有尺度下都找不到好的匹配,则结束搜索
-            print("在所有尺度下均未找到新的匹配实例,搜索结束。")
+            # If no good match found across all scales, end search
+            print("No new matching instances found across all scales, search ended.")
             break
 
     return found_instances
@@ -186,11 +186,11 @@ def visualize_matches(layout_path, bboxes, output_path):
         cv2.rectangle(layout_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
         cv2.putText(layout_img, f"Match {i+1}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
     cv2.imwrite(output_path, layout_img)
-    print(f"可视化结果已保存至: {output_path}")
+    print(f"Visualization result saved to: {output_path}")
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="使用 RoRD 进行多尺度模板匹配")
+    parser = argparse.ArgumentParser(description="Multi-scale template matching using RoRD")
     parser.add_argument('--model_path', type=str, default=config.MODEL_PATH)
     parser.add_argument('--layout', type=str, required=True)
     parser.add_argument('--template', type=str, required=True)
@@ -207,7 +207,7 @@ if __name__ == "__main__":
 
     detected_bboxes = match_template_multiscale(model, layout_image, template_image, transform)
 
-    print("\n检测到的边界框:")
+    print("\nDetected bounding boxes:")
     for bbox in detected_bboxes:
         print(bbox)
 
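The similarity computation at the top of mutual_nearest_neighbor is elided by the hunk above; a hedged reconstruction consistent with the names the diff does show (nn12, ids1, mask, the final torch.stack) could be:

import torch

# Sketch of mutual nearest neighbor matching; only the final two lines
# appear verbatim in the diff, the similarity step is an assumption.
def mutual_nearest_neighbor(descs1, descs2):
    if len(descs1) == 0 or len(descs2) == 0:
        return torch.empty((0, 2), dtype=torch.int64)
    # Descriptors are L2-normalized, so the dot product is cosine similarity
    sim = descs1 @ descs2.T
    nn12 = sim.max(dim=1)  # best match in descs2 for each descriptor in descs1
    nn21 = sim.max(dim=0)  # best match in descs1 for each descriptor in descs2
    ids1 = torch.arange(len(descs1), device=descs1.device)
    # Keep a pair only if the choice is mutual in both directions
    mask = nn21.indices[nn12.indices] == ids1
    matches = torch.stack([ids1[mask], nn12.indices[mask]], dim=1)
    return matches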
models/rord.py
@@ -7,18 +7,18 @@ from torchvision import models
 class RoRD(nn.Module):
     def __init__(self):
         """
-        修复后的 RoRD 模型。
-        - 实现了共享骨干网络,以提高计算效率和减少内存占用。
-        - 确保检测头和描述子头使用相同尺寸的特征图。
+        Repaired RoRD model.
+        - Implements shared backbone network to improve computational efficiency and reduce memory usage.
+        - Ensures detection head and descriptor head use feature maps of the same size.
         """
         super(RoRD, self).__init__()
 
         vgg16_features = models.vgg16(pretrained=False).features
 
-        # 共享骨干网络 - 只使用到 relu4_3,确保特征图尺寸一致
+        # Shared backbone network - only uses up to relu4_3 to ensure consistent feature map dimensions
         self.backbone = nn.Sequential(*list(vgg16_features.children())[:23])
 
-        # 检测头
+        # Detection head
         self.detection_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, padding=1),
             nn.ReLU(inplace=True),
@@ -28,7 +28,7 @@ class RoRD(nn.Module):
             nn.Sigmoid()
         )
 
-        # 描述子头
+        # Descriptor head
         self.descriptor_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, padding=1),
             nn.ReLU(inplace=True),
@@ -39,10 +39,10 @@ class RoRD(nn.Module):
         )
 
     def forward(self, x):
-        # 共享特征提取
+        # Shared feature extraction
         features = self.backbone(x)
 
-        # 检测器和描述子使用相同的特征图
+        # Detector and descriptor use the same feature maps
         detection_map = self.detection_head(features)
         descriptors = self.descriptor_head(features)
 
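A quick sanity check on the backbone slice: vgg16().features[:23] ends at relu4_3, which sits after three of VGG16's five 2x max-pools, so the feature map is 1/8 of the input resolution. This is exactly the factor that the `keypoints = coords.flip(1) * 8.0` rescaling in match.py relies on. A standalone sketch (the diff constructs the model with the older `pretrained=False` spelling; `weights=None` is the equivalent in current torchvision):

import torch
import torch.nn as nn
from torchvision import models

# features[:23] ends at relu4_3, after 3 max-pools -> downsample by 2**3 = 8
backbone = nn.Sequential(*list(models.vgg16(weights=None).features.children())[:23])
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out = backbone(x)
print(out.shape)  # torch.Size([1, 512, 32, 32]); 256 / 32 == 8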
train.py: 140 lines changed
@@ -12,14 +12,14 @@ import argparse
 import logging
 from datetime import datetime
 
-# 导入项目模块
+# Import project modules
 import config
 from models.rord import RoRD
 from utils.data_utils import get_transform
 
-# 设置日志记录
+# Setup logging
 def setup_logging(save_dir):
-    """设置训练日志记录"""
+    """Setup training logging"""
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
 
@@ -34,14 +34,14 @@ def setup_logging(save_dir):
     )
     return logging.getLogger(__name__)
 
-# --- (已修改) 训练专用数据集类 ---
+# --- (Modified) Training-specific dataset class ---
 class ICLayoutTrainingDataset(Dataset):
     def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
         self.image_dir = image_dir
         self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
         self.patch_size = patch_size
         self.transform = transform
-        self.scale_range = scale_range # 新增尺度范围参数
+        self.scale_range = scale_range # New scale range parameter
 
     def __len__(self):
         return len(self.image_paths)
@@ -51,47 +51,47 @@ class ICLayoutTrainingDataset(Dataset):
         image = Image.open(img_path).convert('L')
         W, H = image.size
 
-        # --- 新增:尺度抖动数据增强 ---
-        # 1. 随机选择一个缩放比例
+        # --- New: Scale jittering data augmentation ---
+        # 1. Randomly select a scaling factor
         scale = np.random.uniform(self.scale_range[0], self.scale_range[1])
-        # 2. 根据缩放比例计算需要从原图裁剪的尺寸
+        # 2. Calculate crop size from original image based on scaling factor
         crop_size = int(self.patch_size / scale)
 
         # 确保裁剪尺寸不超过图像边界
         if crop_size > min(W, H):
             crop_size = min(W, H)
 
-        # 3. 随机裁剪
+        # 3. Random cropping
         x = np.random.randint(0, W - crop_size + 1)
         y = np.random.randint(0, H - crop_size + 1)
         patch = image.crop((x, y, x + crop_size, y + crop_size))
 
-        # 4. 将裁剪出的图像块缩放回标准的 patch_size
+        # 4. Resize cropped patch back to standard patch_size
         patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
-        # --- 尺度抖动结束 ---
+        # --- Scale jittering end ---
 
-        # --- 新增:额外的数据增强 ---
-        # 亮度调整
+        # --- New: Additional data augmentation ---
+        # Brightness adjustment
         if np.random.random() < 0.5:
             brightness_factor = np.random.uniform(0.8, 1.2)
             patch = patch.point(lambda x: int(x * brightness_factor))
 
-        # 对比度调整
+        # Contrast adjustment
         if np.random.random() < 0.5:
             contrast_factor = np.random.uniform(0.8, 1.2)
             patch = patch.point(lambda x: int(((x - 128) * contrast_factor) + 128))
 
-        # 添加噪声
+        # Add noise
         if np.random.random() < 0.3:
             patch_np = np.array(patch, dtype=np.float32)
             noise = np.random.normal(0, 5, patch_np.shape)
             patch_np = np.clip(patch_np + noise, 0, 255)
             patch = Image.fromarray(patch_np.astype(np.uint8))
-        # --- 额外数据增强结束 ---
+        # --- Additional data augmentation end ---
 
         patch_np = np.array(patch)
 
-        # 实现8个方向的离散几何变换 (这部分逻辑不变)
+        # Implement 8-direction discrete geometric transformations (this logic remains unchanged)
         theta_deg = np.random.choice([0, 90, 180, 270])
         is_mirrored = np.random.choice([True, False])
         cx, cy = self.patch_size / 2.0, self.patch_size / 2.0
@@ -117,57 +117,57 @@ class ICLayoutTrainingDataset(Dataset):
         H_tensor = torch.from_numpy(H[:2, :]).float()
         return patch, transformed_patch, H_tensor
 
-# --- 特征图变换与损失函数 (改进版) ---
+# --- (Modified) Feature map transformation and loss functions (improved version) ---
 def warp_feature_map(feature_map, H_inv):
     B, C, H, W = feature_map.size()
     grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
     return F.grid_sample(feature_map, grid, align_corners=False)
 
 def compute_detection_loss(det_original, det_rotated, H):
-    """改进的检测损失:使用BCE损失替代MSE"""
+    """Improved detection loss: use BCE loss instead of MSE"""
     with torch.no_grad():
         H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
         warped_det_rotated = warp_feature_map(det_rotated, H_inv)
 
-    # 使用BCE损失,更适合二分类问题
+    # Use BCE loss, more suitable for binary classification problems
     bce_loss = F.binary_cross_entropy(det_original, warped_det_rotated)
 
-    # 添加平滑L1损失作为辅助
+    # Add smooth L1 loss as auxiliary
     smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det_rotated)
 
     return bce_loss + 0.1 * smooth_l1_loss
 
 def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
-    """IC版图专用几何感知描述子损失:编码曼哈顿几何特征"""
+    """IC layout-specific geometric-aware descriptor loss: encodes Manhattan geometric features"""
     B, C, H_feat, W_feat = desc_original.size()
 
-    # 曼哈顿几何感知采样:重点采样边缘和角点区域
+    # Manhattan geometric-aware sampling: focus on edge and corner regions
     num_samples = 200
 
-    # 生成曼哈顿对齐的采样网格(水平和垂直优先)
+    # Generate Manhattan-aligned sampling grid (horizontal and vertical priority)
     h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
     w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
 
-    # 增加曼哈顿方向的采样密度
+    # Increase sampling density in Manhattan directions
     manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
     manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
     manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1).unsqueeze(0).repeat(B, 1, 1)
 
-    # 采样anchor点
+    # Sample anchor points
     anchor = F.grid_sample(desc_original, manhattan_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
 
-    # 计算对应的正样本点
+    # Calculate corresponding positive samples
     coords_hom = torch.cat([manhattan_coords, torch.ones(B, manhattan_coords.size(1), 1, device=manhattan_coords.device)], dim=2)
     M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
     coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
     positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
 
-    # IC版图专用负样本策略:考虑重复结构
+    # IC layout-specific negative sample strategy: consider repetitive structures
     with torch.no_grad():
-        # 1. 几何感知的负样本:曼哈顿变换后的不同区域
+        # 1. Geometric-aware negative samples: different regions after Manhattan transformation
         neg_coords = []
         for b in range(B):
-            # 生成曼哈顿变换后的坐标(90度旋转等)
+            # Generate coordinates after Manhattan transformation (90-degree rotation, etc.)
             angles = [0, 90, 180, 270]
             for angle in angles:
                 if angle != 0:
@@ -181,55 +181,55 @@ def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
 
         neg_coords = torch.stack(neg_coords[:B*num_samples//2]).reshape(B, -1, 2)
 
-        # 2. 特征空间困难负样本
+        # 2. Feature space hard negative samples
         negative_candidates = F.grid_sample(desc_rotated, neg_coords, align_corners=False).squeeze(2).transpose(1, 2)
 
-        # 3. 曼哈顿距离约束的困难样本选择
+        # 3. Manhattan distance constrained hard sample selection
         anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
         negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
 
-        # 使用曼哈顿距离而非欧氏距离
+        # Use Manhattan distance instead of Euclidean distance
         manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
         hard_indices = torch.topk(manhattan_dist, k=anchor.size(1)//2, largest=False)[1]
         negative = torch.gather(negative_candidates, 1, hard_indices)
 
-    # IC版图专用的几何一致性损失
-    # 1. 曼哈顿方向一致性损失
+    # IC layout-specific geometric consistency loss
+    # 1. Manhattan direction consistency loss
     manhattan_loss = 0
     for i in range(anchor.size(1)):
-        # 计算水平和垂直方向的几何一致性
+        # Calculate geometric consistency in horizontal and vertical directions
         anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
         positive_norm = F.normalize(positive[:, i], p=2, dim=1)
 
-        # 鼓励描述子对曼哈顿变换不变
+        # Encourage descriptor invariance to Manhattan transformations
         cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
         manhattan_loss += torch.mean(1 - cos_sim)
 
-    # 2. 稀疏性正则化(IC版图特征稀疏)
+    # 2. Sparsity regularization (IC layout features are sparse)
     sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
 
-    # 3. 二值化特征距离(处理二值化输入)
+    # 3. Binary feature distance (handles binary input)
     binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
 
-    # 综合损失
-    triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean') # 使用L1距离
+    # Combined loss
+    triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean') # Use L1 distance
     geometric_triplet = triplet_loss(anchor, positive, negative)
 
     return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
 
-# --- (已修改) 主函数与命令行接口 ---
+# --- (Modified) Main function and command-line interface ---
 def main(args):
-    # 设置日志记录
+    # Setup logging
     logger = setup_logging(args.save_dir)
 
-    logger.info("--- 开始训练 RoRD 模型 ---")
-    logger.info(f"训练参数: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
-    logger.info(f"数据目录: {args.data_dir}")
-    logger.info(f"保存目录: {args.save_dir}")
+    logger.info("--- Starting RoRD model training ---")
+    logger.info(f"Training parameters: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
+    logger.info(f"Data directory: {args.data_dir}")
+    logger.info(f"Save directory: {args.save_dir}")
 
     transform = get_transform()
 
-    # 在数据集初始化时传入尺度抖动范围
+    # Pass scale jittering range during dataset initialization
     dataset = ICLayoutTrainingDataset(
         args.data_dir,
         patch_size=config.PATCH_SIZE,
@@ -237,35 +237,35 @@ def main(args):
         scale_range=config.SCALE_JITTER_RANGE
     )
 
-    logger.info(f"数据集大小: {len(dataset)}")
+    logger.info(f"Dataset size: {len(dataset)}")
 
-    # 分割训练集和验证集
+    # Split training and validation sets
     train_size = int(0.8 * len(dataset))
     val_size = len(dataset) - train_size
     train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
 
-    logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
+    logger.info(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
 
     train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
     val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
 
     model = RoRD().cuda()
-    logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
+    logger.info(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
 
-    # 添加学习率调度器
+    # Add learning rate scheduler
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
         optimizer, mode='min', factor=0.5, patience=5
     )
 
-    # 早停机制
+    # Early stopping mechanism
     best_val_loss = float('inf')
     patience_counter = 0
     patience = 10
 
     for epoch in range(args.epochs):
-        # 训练阶段
+        # Training phase
         model.train()
         total_train_loss = 0
         total_det_loss = 0
@@ -284,7 +284,7 @@ def main(args):
             optimizer.zero_grad()
             loss.backward()
 
-            # 梯度裁剪,防止梯度爆炸
+            # Gradient clipping to prevent gradient explosion
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
             optimizer.step()
@@ -300,7 +300,7 @@ def main(args):
         avg_det_loss = total_det_loss / len(train_dataloader)
         avg_desc_loss = total_desc_loss / len(train_dataloader)
 
-        # 验证阶段
+        # Validation phase
         model.eval()
         total_val_loss = 0
         total_val_det_loss = 0
@@ -325,20 +325,20 @@ def main(args):
         avg_val_det_loss = total_val_det_loss / len(val_dataloader)
         avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)
 
-        # 学习率调度
+        # Learning rate scheduling
         scheduler.step(avg_val_loss)
 
-        logger.info(f"--- Epoch {epoch+1} 完成 ---")
-        logger.info(f"训练 - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
-        logger.info(f"验证 - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
-        logger.info(f"学习率: {optimizer.param_groups[0]['lr']:.2e}")
+        logger.info(f"--- Epoch {epoch+1} completed ---")
+        logger.info(f"Training - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
+        logger.info(f"Validation - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
+        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
 
-        # 早停检查
+        # Early stopping check
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             patience_counter = 0
 
-            # 保存最佳模型
+            # Save best model
             if not os.path.exists(args.save_dir):
                 os.makedirs(args.save_dir)
             save_path = os.path.join(args.save_dir, 'rord_model_best.pth')
@@ -353,14 +353,14 @@ def main(args):
                     'epochs': args.epochs
                 }
             }, save_path)
-            logger.info(f"最佳模型已保存至: {save_path}")
+            logger.info(f"Best model saved to: {save_path}")
         else:
             patience_counter += 1
             if patience_counter >= patience:
-                logger.info(f"早停触发!{patience} 个epoch没有改善")
+                logger.info(f"Early stopping triggered! No improvement for {patience} epochs")
                 break
 
-    # 保存最终模型
+    # Save final model
     save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
     torch.save({
         'epoch': args.epochs,
@@ -373,11 +373,11 @@ def main(args):
             'epochs': args.epochs
         }
     }, save_path)
-    logger.info(f"最终模型已保存至: {save_path}")
-    logger.info("训练完成!")
+    logger.info(f"Final model saved to: {save_path}")
+    logger.info("Training completed!")
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="训练 RoRD 模型")
+    parser = argparse.ArgumentParser(description="Train RoRD model")
     parser.add_argument('--data_dir', type=str, default=config.LAYOUT_DIR)
     parser.add_argument('--save_dir', type=str, default=config.SAVE_DIR)
     parser.add_argument('--epochs', type=int, default=config.NUM_EPOCHS)
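The dataset hunk above cuts off just before the homography H is built from theta_deg, is_mirrored, and the patch center (cx, cy). A hedged sketch of one standard way to construct it (an assumption, not the commit's exact code; the names mirror the diff, and H[:2, :] is what the dataset stores as H_tensor):

import numpy as np

# Hypothetical construction of a 3x3 homography for one of the 8 discrete
# rotation/mirror transforms around the patch center (cx, cy).
def build_discrete_homography(theta_deg, is_mirrored, cx, cy):
    t = np.deg2rad(theta_deg)
    R = np.array([[np.cos(t), -np.sin(t), 0.0],
                  [np.sin(t),  np.cos(t), 0.0],
                  [0.0,        0.0,       1.0]])
    M = np.eye(3)
    if is_mirrored:
        M[0, 0] = -1.0  # mirror across the vertical axis
    # Translate center to origin, transform, translate back
    T_fwd = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]], dtype=float)
    T_back = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]], dtype=float)
    return T_back @ R @ M @ T_fwd

H = build_discrete_homography(90, True, 128.0, 128.0)
print(H[:2, :])  # the 2x3 part stored as H_tensor in the dataset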
utils/data_utils.py
@@ -3,12 +3,12 @@ from .transforms import SobelTransform
 
 def get_transform():
     """
-    获取统一的图像预处理管道。
-    确保训练、评估和推理使用完全相同的预处理。
+    Get unified image preprocessing pipeline.
+    Ensure training, evaluation, and inference use exactly the same preprocessing.
     """
     return transforms.Compose([
-        SobelTransform(), # 应用 Sobel 边缘检测
+        SobelTransform(), # Apply Sobel edge detection
         transforms.ToTensor(),
-        transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # 适配 VGG 的三通道输入
+        transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # Adapt to VGG's three-channel input
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
     ])
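A short sketch of the tensor contract this pipeline produces (a synthetic grayscale patch stands in for a real layout image): a single-channel PIL image becomes a 3xHxW tensor roughly in [-1, 1] after Normalize(0.5, 0.5), which is what the VGG backbone expects.

import numpy as np
from PIL import Image

from utils.data_utils import get_transform  # assumed import path

img = Image.fromarray((np.random.rand(256, 256) * 255).astype(np.uint8), mode='L')
tensor = get_transform()(img)
print(tensor.shape)                # torch.Size([3, 256, 256])
print(tensor.min(), tensor.max())  # within [-1, 1] after normalization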
utils/transforms.py
@@ -5,13 +5,13 @@ from PIL import Image
 class SobelTransform:
     def __call__(self, image):
         """
-        应用 Sobel 边缘检测,增强 IC 版图的几何边界。
+        Apply Sobel edge detection to enhance geometric boundaries of IC layouts.
 
-        参数:
-            image (PIL.Image): 输入图像(灰度图)。
+        Args:
+            image (PIL.Image): Input image (grayscale).
 
-        返回:
-            PIL.Image: 边缘增强后的图像。
+        Returns:
+            PIL.Image: Edge-enhanced image.
         """
         img_np = np.array(image)
         sobelx = cv2.Sobel(img_np, cv2.CV_64F, 1, 0, ksize=3)
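The diff is truncated after the horizontal gradient. A hedged completion of SobelTransform, assuming the typical finish of combining both gradients into an edge magnitude image (the sobely, magnitude, and rescaling steps are assumptions, not shown in the commit):

import cv2
import numpy as np
from PIL import Image

class SobelTransform:
    def __call__(self, image):
        img_np = np.array(image)
        sobelx = cv2.Sobel(img_np, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(img_np, cv2.CV_64F, 0, 1, ksize=3)  # assumed
        magnitude = np.sqrt(sobelx**2 + sobely**2)
        # Rescale to 0-255 so the downstream ToTensor()/Normalize() behave as usual
        magnitude = (magnitude / (magnitude.max() + 1e-8) * 255).astype(np.uint8)
        return Image.fromarray(magnitude)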