Files
Geo-Layout-Transformer/scripts/visualize_attention.py
2025-09-03 23:01:00 +08:00

65 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# scripts/visualize_attention.py
import argparse
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from src.utils.config_loader import load_config
from src.models.geo_layout_transformer import GeoLayoutTransformer
from src.utils.logging import get_logger
def main():
parser = argparse.ArgumentParser(description="可视化来自已训练模型的注意力图。")
parser.add_argument("--config-file", required=True, help="模型配置文件的路径。")
parser.add_argument("--model-path", required=True, help="已训练模型检查点的路径。")
parser.add_argument("--patch-data", required=True, help="区块数据样本(.pt 文件)的路径。")
args = parser.parse_args()
logger = get_logger("Attention_Visualizer")
logger.info("这是一个用于注意力可视化的占位符脚本。")
logger.info("完整的实现需要加载一个训练好的模型、一个数据样本,然后提取注意力权重。")
# 1. 加载配置和模型
# logger.info("正在加载模型...")
# config = load_config(args.config_file)
# model = GeoLayoutTransformer(config)
# model.load_state_dict(torch.load(args.model_path))
# model.eval()
# 2. 加载一个数据样本
# logger.info(f"正在加载数据样本从 {args.patch_data}")
# sample_data = torch.load(args.patch_data)
# 3. 注册钩子Hook到模型中以提取注意力权重
# 这是一个复杂的过程,需要访问 nn.MultiheadAttention 模块的前向传播过程。
# attention_weights = []
# def hook(module, input, output):
# # output[1] 是注意力权重
# attention_weights.append(output[1])
# model.transformer_core.transformer_encoder.layers[0].self_attn.register_forward_hook(hook)
# 4. 运行一次前向传播以获取权重
# logger.info("正在运行前向传播...")
# with torch.no_grad():
# # 模型需要修改以支持返回注意力权重,或者通过钩子获取
# _ = model(sample_data)
# 5. 绘制注意力图
# if attention_weights:
# logger.info("正在绘制注意力图...")
# # attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len]
# # 我们取第一项,并在所有头上取平均值
# avg_attention = attention_weights[0][0].mean(dim=0).cpu().numpy()
# plt.figure(figsize=(10, 10))
# sns.heatmap(avg_attention, cmap='viridis')
# plt.title("区块之间的平均注意力图")
# plt.xlabel("区块索引")
# plt.ylabel("区块索引")
# plt.show()
# else:
# logger.warning("未能提取注意力权重。")
if __name__ == "__main__":
main()