Files
RoRD-Layout-Recognation/tools/setup_diffusion_training.py
2025-11-09 18:02:40 +08:00

275 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
一键设置扩散训练流程的脚本
此脚本帮助用户:
1. 检查环境
2. 生成扩散数据
3. 配置训练参数
4. 启动训练
"""
import sys
import argparse
import yaml
import subprocess
from pathlib import Path
import logging
def setup_logging():
    """Configure root logging to emit INFO-and-above records on stdout.

    Returns:
        logging.Logger: the logger named after this module.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[stdout_handler],
    )
    module_logger = logging.getLogger(__name__)
    return module_logger
def check_environment(logger, required_packages=None):
    """Verify that required Python packages import and report CUDA availability.

    Args:
        logger: logger used for progress and diagnostic messages.
        required_packages: import names to verify. ``None`` (the default)
            checks ``['torch', 'torchvision', 'numpy', 'PIL', 'yaml']``.

    Returns:
        bool: False if any required package fails to import (an install hint
        is logged), True otherwise. CUDA being unavailable, or torch failing
        during the CUDA probe, only produces a warning.
    """
    logger.info("检查运行环境...")
    # Check Python packages
    if required_packages is None:
        required_packages = ['torch', 'torchvision', 'numpy', 'PIL', 'yaml']
    # Some import names differ from the distribution name pip needs;
    # `pip install PIL` / `pip install yaml` would fail.
    pip_names = {'PIL': 'pillow', 'yaml': 'pyyaml'}
    missing_packages = []
    for package in required_packages:
        try:
            __import__(package)
        except ImportError:
            missing_packages.append(package)
            logger.warning(f"{package} 未安装")
        else:
            logger.info(f"{package} 已安装")
    if missing_packages:
        logger.error(f"缺少必需的包: {missing_packages}")
        install_names = [pip_names.get(name, name) for name in missing_packages]
        logger.info("请安装缺少的包pip install " + " ".join(install_names))
        return False
    # Check CUDA; failures here are informational only.
    try:
        import torch
        if torch.cuda.is_available():
            logger.info(f"✓ CUDA 可用,设备数量: {torch.cuda.device_count()}")
        else:
            logger.warning("✗ CUDA 不可用将使用CPU训练速度较慢")
    except Exception as e:
        logger.warning(f"无法检查CUDA状态: {e}")
    logger.info("环境检查完成")
    return True
def create_sample_config(config_path, logger):
    """Write a sample YAML training configuration to ``config_path``.

    Missing parent directories are created first (the default location,
    ``configs/``, is not created anywhere else). The YAML text is rendered
    before the file is opened, so a serialization error cannot leave an empty
    config file behind.

    Args:
        config_path: destination path (str or Path).
        logger: logger for progress and error messages.

    Returns:
        bool: True if the file was written, False if a filesystem error
        (unwritable location, parent path is a file, ...) prevented it.
    """
    logger.info("创建示例配置文件...")
    config = {
        'training': {
            'learning_rate': 5e-5,
            'batch_size': 8,
            'num_epochs': 50,
            'patch_size': 256,
            'scale_jitter_range': [0.8, 1.2]
        },
        'model': {
            'fpn': {
                'enabled': True,
                'out_channels': 256,
                'levels': [2, 3, 4],
                'norm': 'bn'
            },
            'backbone': {
                'name': 'vgg16',
                'pretrained': False
            },
            'attention': {
                'enabled': False,
                'type': 'none',
                'places': []
            }
        },
        'paths': {
            'layout_dir': 'data/layouts',  # raw layout data directory
            'save_dir': 'models/rord',
            'val_img_dir': 'data/val/images',
            'val_ann_dir': 'data/val/annotations',
            'template_dir': 'data/templates',
            'model_path': 'models/rord/rord_model_best.pth'
        },
        'data_sources': {
            'real': {
                'enabled': True,
                'ratio': 0.7  # 70% real data
            },
            'diffusion': {
                'enabled': True,
                'model_dir': 'models/diffusion',
                'png_dir': 'data/diffusion_generated',
                'ratio': 0.3,  # 30% diffusion-generated data
                'training': {
                    'epochs': 100,
                    'batch_size': 8,
                    'lr': 1e-4,
                    'image_size': 256,
                    'timesteps': 1000,
                    'augment': True
                },
                'generation': {
                    'num_samples': 200,
                    'timesteps': 1000
                }
            }
        },
        'logging': {
            'use_tensorboard': True,
            'log_dir': 'runs',
            'experiment_name': 'diffusion_training'
        }
    }
    config_file = Path(config_path)
    try:
        config_file.parent.mkdir(parents=True, exist_ok=True)
        # With no stream argument yaml.dump returns the document as a str.
        text = yaml.dump(config, default_flow_style=False, allow_unicode=True)
        config_file.write_text(text, encoding='utf-8')
    except OSError as e:
        logger.error(f"无法写入配置文件 {config_path}: {e}")
        return False
    logger.info(f"示例配置文件已创建: {config_path}")
    return True
def setup_directories(logger, directories=None):
    """Create the working directories used by the diffusion training pipeline.

    Args:
        logger: logger for progress and error messages.
        directories: paths to create (parents included, existing ones are
            fine). ``None`` (the default) uses the standard layout relative
            to the current working directory.

    Returns:
        bool: True when every directory exists afterwards, False as soon as
        one cannot be created (the error is logged).
    """
    logger.info("创建目录结构...")
    if directories is None:
        directories = [
            'configs',  # default location of --config
            'data/layouts',
            'data/diffusion_generated',
            'models/diffusion',
            'models/rord',
            'runs',
            'logs'
        ]
    for directory in directories:
        try:
            Path(directory).mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.error(f"无法创建目录 {directory}: {e}")
            return False
        logger.info(f"{directory}")
    logger.info("目录结构创建完成")
    return True
def run_diffusion_pipeline(config_path, logger):
    """Run ``tools/diffusion/generate_diffusion_data.py`` as a child process.

    The script path and data/model/output directories are relative to the
    current working directory. Sample count (200) and ratio (0.3) are passed
    on the command line as fixed values, not read from ``config_path``.

    Args:
        config_path: config file passed through via ``--config`` (str or Path).
        logger: logger for progress and error messages.

    Returns:
        bool: True if the child exited with status 0. On failure both the
        captured stderr and stdout are logged.
    """
    logger.info("运行扩散数据生成流程...")
    cmd = [
        sys.executable, "tools/diffusion/generate_diffusion_data.py",
        "--config", str(config_path),  # str() so the join below accepts Path
        "--data_dir", "data/layouts",
        "--model_dir", "models/diffusion",
        "--output_dir", "data/diffusion_generated",
        "--num_samples", "200",
        "--ratio", "0.3"
    ]
    logger.info(f"执行命令: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        logger.error(f"扩散数据生成失败: {result.stderr}")
        # A child that reports errors via a logging handler on stdout leaves
        # its diagnostics there, so do not discard the captured stdout.
        if result.stdout:
            logger.error(f"子进程输出: {result.stdout}")
        return False
    logger.info("扩散数据生成完成")
    return True
def start_training(config_path, logger):
    """Launch ``train.py`` (relative to the working directory) with ``--config``.

    The child's output is not captured, so training progress shows up in the
    terminal while it runs.

    Returns:
        bool: True if the training process exited with status 0.
    """
    logger.info("启动模型训练...")
    train_cmd = [sys.executable, "train.py", "--config", config_path]
    logger.info(f"执行命令: {' '.join(train_cmd)}")
    completed = subprocess.run(train_cmd, capture_output=False)
    if completed.returncode == 0:
        logger.info("训练完成")
        return True
    logger.error("训练失败")
    return False
def main(argv=None):
    """Parse CLI flags and run the requested setup steps in order.

    Steps: (1) environment check, (2) directory creation, (3) sample config
    creation when ``--config`` does not exist yet, (4) diffusion data
    generation, (5) model training. Steps 1, 4 and 5 can be skipped with
    their flags; ``--only_check`` stops right after step 1.

    Args:
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]``.

    Returns:
        bool: True if every step that ran succeeded, False on the first failure.
    """
    parser = argparse.ArgumentParser(description="一键设置扩散训练流程")
    parser.add_argument("--config", type=str, default="configs/diffusion_config.yaml", help="配置文件路径")
    parser.add_argument("--skip_env_check", action="store_true", help="跳过环境检查")
    parser.add_argument("--skip_diffusion", action="store_true", help="跳过扩散数据生成")
    parser.add_argument("--skip_training", action="store_true", help="跳过模型训练")
    parser.add_argument("--only_check", action="store_true", help="仅检查环境")
    args = parser.parse_args(argv)
    logger = setup_logging()
    logger.info("=== RoRD 扩散训练流程设置 ===")

    # 1. Environment check
    if not args.skip_env_check and not check_environment(logger):
        logger.error("环境检查失败")
        return False
    # Checked outside the env-check branch so --only_check is honored even
    # together with --skip_env_check; otherwise the whole pipeline, including
    # training, would run.
    if args.only_check:
        if args.skip_env_check:
            logger.warning("同时指定了 --only_check 和 --skip_env_check, 未执行任何步骤")
        else:
            logger.info("环境检查完成")
        return True

    # 2. Create the directory structure
    if not setup_directories(logger):
        logger.error("目录创建失败")
        return False

    # 3. Create a sample config file unless one already exists
    config_path = Path(args.config)
    if config_path.exists():
        logger.info(f"使用现有配置文件: {config_path}")
    elif not create_sample_config(args.config, logger):
        logger.error("配置文件创建失败")
        return False

    # 4. Diffusion data generation
    if args.skip_diffusion:
        logger.info("跳过扩散数据生成")
    elif not run_diffusion_pipeline(args.config, logger):
        logger.error("扩散数据生成失败")
        return False

    # 5. Model training
    if args.skip_training:
        logger.info("跳过模型训练")
    elif not start_training(args.config, logger):
        logger.error("训练失败")
        return False

    logger.info("=== 扩散训练流程设置完成 ===")
    logger.info("您可以查看以下文件和目录:")
    logger.info(f"配置文件: {args.config}")
    logger.info("扩散模型: models/diffusion/")
    logger.info("生成数据: data/diffusion_generated/")
    logger.info("训练模型: models/rord/")
    logger.info("训练日志: runs/")
    return True
if __name__ == "__main__":
    # Translate main()'s boolean outcome into a process exit status.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)