add midtern report and change data source

This commit is contained in:
Jiao77
2025-11-09 18:02:40 +08:00
parent 030b9f6804
commit 8ed12915a5
24 changed files with 4230 additions and 369 deletions

View File

@@ -0,0 +1,253 @@
#!/usr/bin/env python3
"""
一键生成扩散数据的脚本:
1. 基于原始数据训练扩散模型
2. 使用训练好的模型生成相似图像
3. 更新配置文件
"""
import os
import sys
import yaml
import argparse
import subprocess
from pathlib import Path
import logging
def setup_logging():
"""设置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
return logging.getLogger(__name__)
def train_diffusion_model(data_dir, model_dir, logger, **train_kwargs):
"""训练扩散模型"""
logger.info("开始训练扩散模型...")
# 构建训练命令
cmd = [
sys.executable, "tools/diffusion/ic_layout_diffusion.py", "train",
"--data_dir", data_dir,
"--output_dir", model_dir,
"--image_size", str(train_kwargs.get("image_size", 256)),
"--batch_size", str(train_kwargs.get("batch_size", 8)),
"--epochs", str(train_kwargs.get("epochs", 100)),
"--lr", str(train_kwargs.get("lr", 1e-4)),
"--timesteps", str(train_kwargs.get("timesteps", 1000)),
"--num_samples", str(train_kwargs.get("num_samples", 50)),
"--save_interval", str(train_kwargs.get("save_interval", 10))
]
if train_kwargs.get("augment", False):
cmd.append("--augment")
# 执行训练
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
logger.error(f"扩散模型训练失败: {result.stderr}")
return False
logger.info("扩散模型训练完成")
return True
def generate_samples(model_dir, output_dir, num_samples, logger, **gen_kwargs):
"""生成样本"""
logger.info(f"开始生成 {num_samples} 个样本...")
# 查找最终模型
model_path = Path(model_dir) / "diffusion_final.pth"
if not model_path.exists():
# 如果没有最终模型,查找最新的检查点
checkpoints = list(Path(model_dir).glob("diffusion_epoch_*.pth"))
if checkpoints:
model_path = max(checkpoints, key=lambda x: int(x.stem.split('_')[-1]))
else:
logger.error(f"{model_dir} 中找不到模型检查点")
return False
logger.info(f"使用模型: {model_path}")
# 构建生成命令
cmd = [
sys.executable, "tools/diffusion/ic_layout_diffusion.py", "generate",
"--checkpoint", str(model_path),
"--output_dir", output_dir,
"--num_samples", str(num_samples),
"--image_size", str(gen_kwargs.get("image_size", 256)),
"--timesteps", str(gen_kwargs.get("timesteps", 1000))
]
# 执行生成
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
logger.error(f"样本生成失败: {result.stderr}")
return False
logger.info("样本生成完成")
return True
def update_config(config_path, output_dir, ratio, logger):
"""更新配置文件"""
logger.info(f"更新配置文件: {config_path}")
# 读取配置
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# 确保必要的结构存在
if 'synthetic' not in config:
config['synthetic'] = {}
# 更新扩散配置
config['synthetic']['enabled'] = True
config['synthetic']['ratio'] = 0.0 # 禁用程序化合成
if 'diffusion' not in config['synthetic']:
config['synthetic']['diffusion'] = {}
config['synthetic']['diffusion']['enabled'] = True
config['synthetic']['diffusion']['png_dir'] = output_dir
config['synthetic']['diffusion']['ratio'] = ratio
# 保存配置
with open(config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
logger.info(f"配置文件更新完成,扩散数据比例: {ratio}")
def validate_generated_data(output_dir, logger):
"""验证生成的数据"""
logger.info("验证生成的数据...")
output_path = Path(output_dir)
if not output_path.exists():
logger.error(f"输出目录不存在: {output_dir}")
return False
# 统计生成的图像
png_files = list(output_path.glob("*.png"))
if not png_files:
logger.error("没有找到生成的PNG图像")
return False
logger.info(f"验证通过,生成了 {len(png_files)} 个图像")
return True
def main():
parser = argparse.ArgumentParser(description="一键生成扩散数据管线")
parser.add_argument("--config", type=str, required=True, help="配置文件路径")
parser.add_argument("--data_dir", type=str, help="原始数据目录(覆盖配置文件)")
parser.add_argument("--model_dir", type=str, default="models/diffusion", help="扩散模型保存目录")
parser.add_argument("--output_dir", type=str, default="data/diffusion_generated", help="生成数据保存目录")
parser.add_argument("--num_samples", type=int, default=200, help="生成的样本数量")
parser.add_argument("--ratio", type=float, default=0.3, help="扩散数据在训练中的比例")
parser.add_argument("--skip_training", action="store_true", help="跳过训练,直接生成")
parser.add_argument("--model_checkpoint", type=str, help="指定模型检查点路径skip_training时使用")
# 训练参数
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=8, help="批次大小")
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
parser.add_argument("--image_size", type=int, default=256, help="图像尺寸")
parser.add_argument("--augment", action="store_true", help="启用数据增强")
args = parser.parse_args()
# 设置日志
logger = setup_logging()
# 读取配置文件获取数据目录
config_path = Path(args.config)
if not config_path.exists():
logger.error(f"配置文件不存在: {config_path}")
return False
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# 确定数据目录
if args.data_dir:
data_dir = args.data_dir
else:
# 从配置文件获取数据目录
config_dir = config_path.parent
layout_dir = config.get('paths', {}).get('layout_dir', 'data/layouts')
data_dir = str(config_dir / layout_dir)
data_path = Path(data_dir)
if not data_path.exists():
logger.error(f"数据目录不存在: {data_path}")
return False
logger.info(f"使用数据目录: {data_path}")
logger.info(f"模型保存目录: {args.model_dir}")
logger.info(f"生成数据目录: {args.output_dir}")
logger.info(f"生成样本数量: {args.num_samples}")
logger.info(f"训练比例: {args.ratio}")
# 1. 训练扩散模型(如果需要)
if not args.skip_training:
success = train_diffusion_model(
data_dir=data_dir,
model_dir=args.model_dir,
logger=logger,
image_size=args.image_size,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
num_samples=args.num_samples,
augment=args.augment
)
if not success:
logger.error("扩散模型训练失败")
return False
else:
logger.info("跳过训练步骤")
# 2. 生成样本
success = generate_samples(
model_dir=args.model_dir,
output_dir=args.output_dir,
num_samples=args.num_samples,
logger=logger,
image_size=args.image_size
)
if not success:
logger.error("样本生成失败")
return False
# 3. 验证生成的数据
if not validate_generated_data(args.output_dir, logger):
logger.error("数据验证失败")
return False
# 4. 更新配置文件
update_config(
config_path=args.config,
output_dir=args.output_dir,
ratio=args.ratio,
logger=logger
)
logger.info("=== 扩散数据生成管线完成 ===")
logger.info(f"生成数据位置: {args.output_dir}")
logger.info(f"配置文件已更新: {args.config}")
logger.info(f"扩散数据比例: {args.ratio}")
return True
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)