# scripts/visualize_attention.py import argparse import torch import matplotlib.pyplot as plt import seaborn as sns from src.utils.config_loader import load_config from src.models.geo_layout_transformer import GeoLayoutTransformer from src.utils.logging import get_logger def main(): parser = argparse.ArgumentParser(description="可视化来自已训练模型的注意力图。") parser.add_argument("--config-file", required=True, help="模型配置文件的路径。") parser.add_argument("--model-path", required=True, help="已训练模型检查点的路径。") parser.add_argument("--patch-data", required=True, help="区块数据样本(.pt 文件)的路径。") args = parser.parse_args() logger = get_logger("Attention_Visualizer") logger.info("这是一个用于注意力可视化的占位符脚本。") logger.info("完整的实现需要加载一个训练好的模型、一个数据样本,然后提取注意力权重。") # 1. 加载配置和模型 # logger.info("正在加载模型...") # config = load_config(args.config_file) # model = GeoLayoutTransformer(config) # model.load_state_dict(torch.load(args.model_path)) # model.eval() # 2. 加载一个数据样本 # logger.info(f"正在加载数据样本从 {args.patch_data}") # sample_data = torch.load(args.patch_data) # 3. 注册钩子(Hook)到模型中以提取注意力权重 # 这是一个复杂的过程,需要访问 nn.MultiheadAttention 模块的前向传播过程。 # attention_weights = [] # def hook(module, input, output): # # output[1] 是注意力权重 # attention_weights.append(output[1]) # model.transformer_core.transformer_encoder.layers[0].self_attn.register_forward_hook(hook) # 4. 运行一次前向传播以获取权重 # logger.info("正在运行前向传播...") # with torch.no_grad(): # # 模型需要修改以支持返回注意力权重,或者通过钩子获取 # _ = model(sample_data) # 5. 绘制注意力图 # if attention_weights: # logger.info("正在绘制注意力图...") # # attention_weights[0] 的形状是 [batch_size, num_heads, seq_len, seq_len] # # 我们取第一项,并在所有头上取平均值 # avg_attention = attention_weights[0][0].mean(dim=0).cpu().numpy() # plt.figure(figsize=(10, 10)) # sns.heatmap(avg_attention, cmap='viridis') # plt.title("区块之间的平均注意力图") # plt.xlabel("区块索引") # plt.ylabel("区块索引") # plt.show() # else: # logger.warning("未能提取注意力权重。") if __name__ == "__main__": main()