add midtern report and change data source
This commit is contained in:
253
tools/diffusion/generate_diffusion_data.py
Normal file
253
tools/diffusion/generate_diffusion_data.py
Normal file
@@ -0,0 +1,253 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
一键生成扩散数据的脚本:
|
||||
1. 基于原始数据训练扩散模型
|
||||
2. 使用训练好的模型生成相似图像
|
||||
3. 更新配置文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""设置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train_diffusion_model(data_dir, model_dir, logger, **train_kwargs):
|
||||
"""训练扩散模型"""
|
||||
logger.info("开始训练扩散模型...")
|
||||
|
||||
# 构建训练命令
|
||||
cmd = [
|
||||
sys.executable, "tools/diffusion/ic_layout_diffusion.py", "train",
|
||||
"--data_dir", data_dir,
|
||||
"--output_dir", model_dir,
|
||||
"--image_size", str(train_kwargs.get("image_size", 256)),
|
||||
"--batch_size", str(train_kwargs.get("batch_size", 8)),
|
||||
"--epochs", str(train_kwargs.get("epochs", 100)),
|
||||
"--lr", str(train_kwargs.get("lr", 1e-4)),
|
||||
"--timesteps", str(train_kwargs.get("timesteps", 1000)),
|
||||
"--num_samples", str(train_kwargs.get("num_samples", 50)),
|
||||
"--save_interval", str(train_kwargs.get("save_interval", 10))
|
||||
]
|
||||
|
||||
if train_kwargs.get("augment", False):
|
||||
cmd.append("--augment")
|
||||
|
||||
# 执行训练
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"扩散模型训练失败: {result.stderr}")
|
||||
return False
|
||||
|
||||
logger.info("扩散模型训练完成")
|
||||
return True
|
||||
|
||||
|
||||
def generate_samples(model_dir, output_dir, num_samples, logger, **gen_kwargs):
|
||||
"""生成样本"""
|
||||
logger.info(f"开始生成 {num_samples} 个样本...")
|
||||
|
||||
# 查找最终模型
|
||||
model_path = Path(model_dir) / "diffusion_final.pth"
|
||||
if not model_path.exists():
|
||||
# 如果没有最终模型,查找最新的检查点
|
||||
checkpoints = list(Path(model_dir).glob("diffusion_epoch_*.pth"))
|
||||
if checkpoints:
|
||||
model_path = max(checkpoints, key=lambda x: int(x.stem.split('_')[-1]))
|
||||
else:
|
||||
logger.error(f"在 {model_dir} 中找不到模型检查点")
|
||||
return False
|
||||
|
||||
logger.info(f"使用模型: {model_path}")
|
||||
|
||||
# 构建生成命令
|
||||
cmd = [
|
||||
sys.executable, "tools/diffusion/ic_layout_diffusion.py", "generate",
|
||||
"--checkpoint", str(model_path),
|
||||
"--output_dir", output_dir,
|
||||
"--num_samples", str(num_samples),
|
||||
"--image_size", str(gen_kwargs.get("image_size", 256)),
|
||||
"--timesteps", str(gen_kwargs.get("timesteps", 1000))
|
||||
]
|
||||
|
||||
# 执行生成
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"样本生成失败: {result.stderr}")
|
||||
return False
|
||||
|
||||
logger.info("样本生成完成")
|
||||
return True
|
||||
|
||||
|
||||
def update_config(config_path, output_dir, ratio, logger):
|
||||
"""更新配置文件"""
|
||||
logger.info(f"更新配置文件: {config_path}")
|
||||
|
||||
# 读取配置
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 确保必要的结构存在
|
||||
if 'synthetic' not in config:
|
||||
config['synthetic'] = {}
|
||||
|
||||
# 更新扩散配置
|
||||
config['synthetic']['enabled'] = True
|
||||
config['synthetic']['ratio'] = 0.0 # 禁用程序化合成
|
||||
|
||||
if 'diffusion' not in config['synthetic']:
|
||||
config['synthetic']['diffusion'] = {}
|
||||
|
||||
config['synthetic']['diffusion']['enabled'] = True
|
||||
config['synthetic']['diffusion']['png_dir'] = output_dir
|
||||
config['synthetic']['diffusion']['ratio'] = ratio
|
||||
|
||||
# 保存配置
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
logger.info(f"配置文件更新完成,扩散数据比例: {ratio}")
|
||||
|
||||
|
||||
def validate_generated_data(output_dir, logger):
|
||||
"""验证生成的数据"""
|
||||
logger.info("验证生成的数据...")
|
||||
|
||||
output_path = Path(output_dir)
|
||||
if not output_path.exists():
|
||||
logger.error(f"输出目录不存在: {output_dir}")
|
||||
return False
|
||||
|
||||
# 统计生成的图像
|
||||
png_files = list(output_path.glob("*.png"))
|
||||
if not png_files:
|
||||
logger.error("没有找到生成的PNG图像")
|
||||
return False
|
||||
|
||||
logger.info(f"验证通过,生成了 {len(png_files)} 个图像")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="一键生成扩散数据管线")
|
||||
parser.add_argument("--config", type=str, required=True, help="配置文件路径")
|
||||
parser.add_argument("--data_dir", type=str, help="原始数据目录(覆盖配置文件)")
|
||||
parser.add_argument("--model_dir", type=str, default="models/diffusion", help="扩散模型保存目录")
|
||||
parser.add_argument("--output_dir", type=str, default="data/diffusion_generated", help="生成数据保存目录")
|
||||
parser.add_argument("--num_samples", type=int, default=200, help="生成的样本数量")
|
||||
parser.add_argument("--ratio", type=float, default=0.3, help="扩散数据在训练中的比例")
|
||||
parser.add_argument("--skip_training", action="store_true", help="跳过训练,直接生成")
|
||||
parser.add_argument("--model_checkpoint", type=str, help="指定模型检查点路径(skip_training时使用)")
|
||||
|
||||
# 训练参数
|
||||
parser.add_argument("--epochs", type=int, default=100, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="批次大小")
|
||||
parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
|
||||
parser.add_argument("--image_size", type=int, default=256, help="图像尺寸")
|
||||
parser.add_argument("--augment", action="store_true", help="启用数据增强")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 设置日志
|
||||
logger = setup_logging()
|
||||
|
||||
# 读取配置文件获取数据目录
|
||||
config_path = Path(args.config)
|
||||
if not config_path.exists():
|
||||
logger.error(f"配置文件不存在: {config_path}")
|
||||
return False
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 确定数据目录
|
||||
if args.data_dir:
|
||||
data_dir = args.data_dir
|
||||
else:
|
||||
# 从配置文件获取数据目录
|
||||
config_dir = config_path.parent
|
||||
layout_dir = config.get('paths', {}).get('layout_dir', 'data/layouts')
|
||||
data_dir = str(config_dir / layout_dir)
|
||||
|
||||
data_path = Path(data_dir)
|
||||
if not data_path.exists():
|
||||
logger.error(f"数据目录不存在: {data_path}")
|
||||
return False
|
||||
|
||||
logger.info(f"使用数据目录: {data_path}")
|
||||
logger.info(f"模型保存目录: {args.model_dir}")
|
||||
logger.info(f"生成数据目录: {args.output_dir}")
|
||||
logger.info(f"生成样本数量: {args.num_samples}")
|
||||
logger.info(f"训练比例: {args.ratio}")
|
||||
|
||||
# 1. 训练扩散模型(如果需要)
|
||||
if not args.skip_training:
|
||||
success = train_diffusion_model(
|
||||
data_dir=data_dir,
|
||||
model_dir=args.model_dir,
|
||||
logger=logger,
|
||||
image_size=args.image_size,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
lr=args.lr,
|
||||
num_samples=args.num_samples,
|
||||
augment=args.augment
|
||||
)
|
||||
if not success:
|
||||
logger.error("扩散模型训练失败")
|
||||
return False
|
||||
else:
|
||||
logger.info("跳过训练步骤")
|
||||
|
||||
# 2. 生成样本
|
||||
success = generate_samples(
|
||||
model_dir=args.model_dir,
|
||||
output_dir=args.output_dir,
|
||||
num_samples=args.num_samples,
|
||||
logger=logger,
|
||||
image_size=args.image_size
|
||||
)
|
||||
if not success:
|
||||
logger.error("样本生成失败")
|
||||
return False
|
||||
|
||||
# 3. 验证生成的数据
|
||||
if not validate_generated_data(args.output_dir, logger):
|
||||
logger.error("数据验证失败")
|
||||
return False
|
||||
|
||||
# 4. 更新配置文件
|
||||
update_config(
|
||||
config_path=args.config,
|
||||
output_dir=args.output_dir,
|
||||
ratio=args.ratio,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
logger.info("=== 扩散数据生成管线完成 ===")
|
||||
logger.info(f"生成数据位置: {args.output_dir}")
|
||||
logger.info(f"配置文件已更新: {args.config}")
|
||||
logger.info(f"扩散数据比例: {args.ratio}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user