initial commit

This commit is contained in:
Jiao77
2025-08-25 17:54:08 +08:00
commit f187abe72a
28 changed files with 1703 additions and 0 deletions

View File

@@ -0,0 +1,63 @@
import argparse
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from src.utils.config_loader import load_config
from src.models.geo_layout_transformer import GeoLayoutTransformer
from src.utils.logging import get_logger
def main():
parser = argparse.ArgumentParser(description="可视化来自已训练模型的注意力图。")
parser.add_argument("--config-file", required=True, help="模型配置文件的路径。")
parser.add_argument("--model-path", required=True, help="已训练模型检查点的路径。")
parser.add_argument("--patch-data", required=True, help="区块数据样本(.pt 文件)的路径。")
args = parser.parse_args()
logger = get_logger("Attention_Visualizer")
logger.info("这是一个用于注意力可视化的占位符脚本。")
logger.info("完整的实现需要加载一个训练好的模型、一个数据样本,然后提取注意力权重。")
# 1. 加载配置和模型
# logger.info("正在加载模型...")
# config = load_config(args.config_file)
# model = GeoLayoutTransformer(config)
# model.load_state_dict(torch.load(args.model_path))
# model.eval()
# 2. 加载一个数据样本
# logger.info(f"正在加载数据样本从 {args.patch_data}")
# sample_data = torch.load(args.patch_data)
# 3. 注册钩子Hook到模型中以提取注意力权重
# 这是一个复杂的过程,需要访问 nn.MultiheadAttention 模块的前向传播过程。
# attention_weights = []
# def hook(module, input, output):
# # output[1] 是注意力权重
# attention_weights.append(output[1])
# model.transformer_core.transformer_encoder.layers[0].self_attn.register_forward_hook(hook)
# 4. 运行一次前向传播以获取权重
# logger.info("正在运行前向传播...")
# with torch.no_grad():
# # 模型需要修改以支持返回注意力权重,或者通过钩子获取
# _ = model(sample_data)
# 5. 绘制注意力图
# if attention_weights:
# logger.info("正在绘制注意力图...")
# # attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len]
# # 我们取第一项,并在所有头上取平均值
# avg_attention = attention_weights[0][0].mean(dim=0).cpu().numpy()
# plt.figure(figsize=(10, 10))
# sns.heatmap(avg_attention, cmap='viridis')
# plt.title("区块之间的平均注意力图")
# plt.xlabel("区块索引")
# plt.ylabel("区块索引")
# plt.show()
# else:
# logger.warning("未能提取注意力权重。")
if __name__ == "__main__":
main()