Visualising VLM attention

Visualising the visualiser

Computer Vision
LLMs
Where does a VLM actually look when it sees an image?
Author

Salman Naqvi

Published

Monday, 17 November 2025

In this notebook, I attempt to visualize how a VLM places attention on any input image, and on any input instruction.

This notebook follows the fastai style guide.

Important

Some of the cell outputs have been deleted to keep the PDF concise. Rerun the notebook to see the full outputs.

Setup

from contextlib import contextmanager
import traceback
@contextmanager
def suplog():
  "Print, rather than propagate, any `Exception` raised inside the `with` body so later notebook cells still run."
  try:
    yield
  except Exception as err:
    # Show the full traceback (including chained causes) for later reference, then swallow the error.
    # Only `Exception` subclasses are caught, so e.g. KeyboardInterrupt still interrupts the notebook.
    traceback.print_exception(type(err), err, err.__traceback__)
with suplog(): 5/0
Traceback (most recent call last):
  File "/tmp/ipykernel_265136/256214625.py", line 5, in suplog
    try: yield
  File "/tmp/ipykernel_265136/256214625.py", line 7, in <module>
    with suplog(): 5/0
ZeroDivisionError: division by zero

This is a nifty snippet that keeps errors in the notebook for future reference, while still letting me run all cells without halting execution.

from fastcore.all import *
if in_colab():
  ! pip install qwen-vl-utils[decord]
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoConfig
cp = 'BAAI/RoboBrain2.0-3B'
vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(cp, attn_implementation='eager', torch_dtype="float16", device_map="auto")
proc = AutoProcessor.from_pretrained(cp)
tokz = AutoTokenizer.from_pretrained(cp)
cfg = AutoConfig.from_pretrained(cp)
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
vlm.model.config._attn_implementation, vlm.visual.config._attn_implementation
('eager', 'eager')
cfg
Qwen2_5_VLConfig {
  "architectures": [
    "Qwen2_5_VLForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "image_token_id": 151655,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 128000,
  "max_window_layers": 70,
  "model_type": "qwen2_5_vl",
  "num_attention_heads": 16,
  "num_hidden_layers": 36,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_scaling": {
    "mrope_section": [
      16,
      24,
      24
    ],
    "rope_type": "default",
    "type": "default"
  },
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.50.0",
  "use_cache": true,
  "use_sliding_window": false,
  "video_token_id": 151656,
  "vision_config": {
    "depth": 32,
    "fullatt_block_indexes": [
      7,
      15,
      23,
      31
    ],
    "hidden_act": "silu",
    "hidden_size": 1280,
    "in_channels": 3,
    "in_chans": 3,
    "intermediate_size": 3420,
    "model_type": "qwen2_5_vl",
    "num_heads": 16,
    "out_hidden_size": 2048,
    "patch_size": 14,
    "spatial_merge_size": 2,
    "spatial_patch_size": 14,
    "temporal_patch_size": 2,
    "tokens_per_second": 2,
    "torch_dtype": "bfloat16",
    "window_size": 112
  },
  "vision_end_token_id": 151653,
  "vision_start_token_id": 151652,
  "vision_token_id": 151654,
  "vocab_size": 151936
}

Getting Started

In this notebook, I’ll be using the recent RoboBrain 2.0 model, which blows other VLMs out of the water. It’s a refreshing model: it performs so well because the team worked directly with the data.

Future note to self: if too much attention is placed on the system prompt, remove it.

from IPython.display import Image
img_p = 'http://images.cocodataset.org/val2017/000000039769.jpg'
Image(img_p)

# q = "Which way is the cat on the right facing?<think>"
q = "Describe the cat on the left.<think>"
msgs = [
  {
    'role': 'user',
    'content': [
        {'type': 'image', 'image': img_p if img_p.startswith('http') else f'file://{img_p}'},
        {'type': 'text', 'text': f'{q}'},
    ],
  },
]
txt_inp = proc.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True); txt_inp
'<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the cat on the left.<think><|im_end|>\n<|im_start|>assistant\n'

I’ll remove the system prompt as it seems the VLM places a lot of attention there. If I leave it in, the visualizations become skewed.

txt_inp = txt_inp.replace('<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n', ''); txt_inp
'<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the cat on the left.<think><|im_end|>\n<|im_start|>assistant\n'
from qwen_vl_utils import process_vision_info
img_inp, vid_inp = process_vision_info(msgs)

Here, the image has been resized so that each dimension is a multiple of 28 (the 14-pixel patch/kernel size times the 2×2 spatial merge), and therefore also a multiple of 14.

img_inp, cfg.vision_config.patch_size
([<PIL.Image.Image image mode=RGB size=644x476>], 14)
644/14,476/14
(46.0, 34.0)
vid_inp
inps = proc(text=txt_inp, images=img_inp, videos=vid_inp, padding=True, return_tensors='pt').to('cuda'); inps
{'input_ids': tensor([[151644,    872,    198, 151652, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655,
         151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151653,
          74785,    279,   8251,    389,    279,   2115,  15757,  26865,     29,
         151645,    198, 151644,  77091,    198]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]], device='cuda:0'), 'pixel_values': tensor([[ 0.2515,  0.3099,  0.3391,  ..., -0.6270, -0.3995, -0.4990],
        [ 0.5143,  0.2807,  0.5581,  ..., -0.2857, -0.4137, -0.2573],
        [-0.0113,  0.0909, -0.0842,  ..., -1.0963, -1.0252, -0.9683],
        ...,
        [ 1.7114,  1.6238,  1.6238,  ...,  1.1505,  1.0652,  1.0225],
        [ 1.4486,  1.5800,  1.5216,  ..., -0.3426, -0.2146,  0.2688],
        [ 1.6530,  1.6676,  1.5508,  ..., -1.0678, -0.8545, -0.8830]],
       device='cuda:0'), 'image_grid_thw': tensor([[ 1, 34, 46]], device='cuda:0')}
inps.keys()
dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw'])
inps.input_ids.shape
torch.Size([1, 410])

Here I can see the tokens I might need to be aware of when visualizing the attention. If I keep them, they may skew the visualization, as a lot of attention may be placed on those tokens.

Most likely, I need to be wary of the very first <|im_start|>, and set its attention to zero when visualizing.

tokz.convert_ids_to_tokens(inps.input_ids[0])
o = vlm.generate(**inps, max_new_tokens=768, do_sample=True, temperature=.7, return_dict_in_generate=True, output_attentions=True); o
proc.batch_decode(o.sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False)

Both the input prompt and the generated answer in total have this many tokens (the second dimension in the shape below.)

o.sequences.shape
torch.Size([1, 437])
trimmed_o = L(out_ids[len(in_ids):] for in_ids,out_ids in zip(inps.input_ids, o['sequences']))
proc.batch_decode(trimmed_o, skip_special_tokens=False, clean_up_tokenization_spaces=False)
['The cat on the left is a tiger-striped cat wearing a green tag. It is laying down comfortably and appears to be relaxed.<|im_end|>']

The generated answer itself has this many tokens.

trimmed_o[0].shape
torch.Size([27])

What I’m going to do next is generate the attention matrix of the model before it generated the first token, by averaging the attention of every layer.

There were so many tokens generated in the answer. Therefore, we have so many attention states.

L(o.attentions)

For each of the so many tokens generated, I can access the attention state of each of the 36 layers in the model (`num_hidden_layers` in the config).

L(o.attentions[0])

The input prompt had so many tokens (the second dimension).

inps.input_ids.shape
torch.Size([1, 410])

Every layer stores the attention for each of the so many tokens.

o.attentions[0][0].shape
torch.Size([1, 16, 410, 410])
with suplog(): o.attentions[0][28].shape

Looking at the output shape, there are two repeating dimensions. 410 corresponds to the input sequence length. However, it appears twice. The reason for this is that attention is all about relationships.

It’s best to think of this as a grid or embedding of sorts. Each token is mapped to every other token.

I’ll take a simple input sequence as an example ['<image>', 'the', 'cat']. To calculate the attention of this sequence, the model needs to know how important each token is to each other.

· <image> (being looked at) the (being looked at) cat (being looked at)
<image> (looking) score(image, image) score(image, the) score(image, cat)
the (looking) score(the, image) score(the, the) score(the, cat)
cat (looking) score(cat, image) score(cat, the) score(cat, cat)

Specifically speaking, the first sequence length (the rows) is the query dimension–the token that is doing the looking. The second sequence length (the columns) is the key dimension–the token that is being looked at.

So a layer with shape [1, 16, 410, 410] is a layer with 1 batch, 16 attention heads, and a 410x410 attention matrix. An entry [i,j] tells us how much attention token i is paying to token j.

When the first token is generated, it gets appended to the input prompt. This appended form is then used to generate the next token.

I can see that the sequence length has indeed been increased by 1 after the first token has been generated.

o.attentions[1][0].shape
torch.Size([1, 16, 1, 411])

However, the last two dimensions are now not the same. Something smart is happening here: computation is being saved.

Let’s say that the last two dimensions grew to 411x411; that would mean a lot of duplicated, redundant computation.

  • Step 0 (Prompt): compute attention for 410 tokens. Matrix is 410x410.
  • Step 1 (Generate 1st Token): The sequence is now 411 tokens long. Recompute attention for all 411 tokens. Matrix would be 411x411.
  • Step 2 (Generate 2nd Token): Sequence is 412 tokens. Recalculate attention for all 412. Matrix would be 412x412.

Attention for tokens that have already been processed would be recalculated over and over, which is computationally wasteful. Instead, the model performs what’s called KV caching. Because the model is causal, the key and value vectors of earlier tokens never change once computed, so they can be cached and reused.

  1. Processing the Prompt
    • The model processes the initial 410 tokens and calculates the query, key, and value vectors for each token.
    • Causal self-attention is performed, where every token’s query is compared against the keys of itself and all earlier tokens.
    • The key and value vectors for all 410 tokens are then saved to a KV cache.
  2. Generating Each Subsequent Token
    • The query vector for the most recently generated token (fed back in as the next input) is calculated.
    • On the first such step, this single query is compared against the keys of all 411 tokens (410 cached keys + the key for the new token itself).

Only the single row of the attention matrix that is needed to predict the next token is calculated.

I’ll begin by calculating the average attention for the prompt, and for the first layer of the model

l = o.attentions[0][0]; l.shape
torch.Size([1, 16, 410, 410])
l_attns = l.squeeze(0); l_attns.shape
torch.Size([16, 410, 410])

I squeeze because the batch size is 1. Now, I’ll average the attention scores across all heads in this layer to produce a single attention map/grid/matrix.

avg_attns = l_attns.mean(dim=0); avg_attns.shape, avg_attns.min(), avg_attns.max()
(torch.Size([410, 410]),
 tensor(0., device='cuda:0', dtype=torch.float16),
 tensor(0.9375, device='cuda:0', dtype=torch.float16))

In autoregressive models, the attention for token i is used to predict token i+1. The final token in the prompt doesn’t predict a new token within the prompt itself, so we don’t need to visualize its attention.

avg_attns[:-1].shape
torch.Size([409, 410])
cur = avg_attns[:-1].cpu().clone()

The first token is typically a special “beginning of sequence” token. Including it can skew the visualization, since the other tokens tend to place a disproportionate amount of attention on it. In this case, it’s <|im_start|>.

cur[1:,0] = 0.

Now since some values are 0, the attention scores for each token no longer sum to 1.

for i in range(0,10,2): print(f'Token {i}: {avg_attns[:-1][i,:].sum()}')
Token 0: 1.0
Token 2: 1.0
Token 4: 1.0
Token 6: 1.0
Token 8: 1.0
for i in range(0,10,2): print(f'Token {i}: {cur[i,:].sum()}')
Token 0: 1.0
Token 2: 0.85107421875
Token 4: 0.93994140625
Token 6: 0.927734375
Token 8: 0.966796875

I’ll renormalize.

cur[1:].shape
torch.Size([408, 410])
cur[1:].sum(-1).shape
torch.Size([408])
cur[1:].sum(-1,keepdim=True).shape
torch.Size([408, 1])
cur[1:] = cur[1:]/cur[1:].sum(-1, keepdim=True)
for i in range(0,10,2): print(f'Token {i}: {cur[i,:].sum()}')
Token 0: 1.0
Token 2: 1.0
Token 4: 0.99951171875
Token 6: 0.99951171875
Token 8: 1.0

And there I have the aggregated attention map for this layer!

Now I’ll compute the aggregated attention map for every layer, and then take the average across all layers.

agg_prompt_attn = L()
for i,l in enumerate(o['attentions'][0]):
  l_attns = l.squeeze(0)
  avg_attns = l_attns.mean(dim=0)
  cur = avg_attns[:-1].cpu().clone()
  cur[1:,0] = 0
  cur[1:] = cur[1:]/cur[1:].sum(-1, keepdim=True)
  agg_prompt_attn.append(cur)
agg_prompt_attn, agg_prompt_attn[0].shape
import torch as t
agg_prompt_attn = t.stack(tuple(agg_prompt_attn)).mean(dim=0); agg_prompt_attn.shape
torch.Size([409, 410])

What I’ve just done is, for each layer, average the attention across the 16 heads. Then I averaged those per-layer averages across all layers.

Now I want to aggregate the attentions map for generating only the 1st token, the map used to generate only the 2nd token, and so on. Rather than prompt+1st token, prompt+2nd token, and so on. The main difference here is how the vector is created.

avg_attns.shape
torch.Size([410, 410])
avg_attns[-1].shape, avg_attns[-1][1:].shape
(torch.Size([410]), torch.Size([409]))
v = t.concat((
  t.tensor([0.]), 
  avg_attns[-1][1:].cpu(), 
  t.tensor([0.]) 
  ))
v.shape
torch.Size([411])
  1. The first entry is zero, mirroring how attention to the first token (<|im_start|>) was zeroed out for the prompt.
  2. On first generation, there’s a row for each token in the prompt, so I take [-1] to get the row of the most recent token.
  3. The last entry is a zero placeholder for the token about to be generated. It can’t receive any attention since it doesn’t exist yet.
def aggregrate_llm_attention(attn):
  """Collapse one generation step's per-layer attentions into a single attention vector.

  `attn` is an iterable of per-layer attention tensors; each is squeezed on dim 0 and averaged
  over its next dim (assumed (1, heads, q_len, k_len) with batch size 1, as in `o.attentions`).
  For each layer, the last query row is taken, its first key is replaced by 0, and a trailing 0
  is appended (so the result is one element longer than the key length). Each layer's vector is
  renormalised to sum to 1, then all layers are averaged. Returns a CPU tensor.
  """
  def _layer_vec(layer):
    head_mean = layer.squeeze(0).mean(dim=0)   # average over attention heads
    newest = head_mean[-1][1:].cpu()           # last query row without key 0 (here <|im_start|>)
    padded = t.concat((t.tensor([0.]), newest, t.tensor([0.])))
    return padded / padded.sum()
  return t.stack([_layer_vec(layer) for layer in attn]).mean(dim=0)
L(o.attentions)
agg_llm_attn = L(map(aggregrate_llm_attention, o.attentions)); agg_llm_attn

Now I’ll stack the prompt attention and the response attention. The vectors have different lengths, so shorter vectors are zero-padded to match the length of the longest vector.

vecs = [t.tensor([1])]+list(agg_prompt_attn)+list(agg_llm_attn)
max_len = max(v.shape[0] for v in vecs); max_len
437
len(vecs), vecs
vecs[1].shape
torch.Size([410])

Now I’ll turn this into a PyTorch tensor.

attn_mtx = t.stack([t.concat((v, t.zeros(max_len-v.shape[0]))) for v in vecs]); attn_mtx.shape
torch.Size([437, 437])

Now it’s time to visualize! The matrix is raised to the power 1/gamma, which compresses the range of values, so a higher gamma factor highlights lower attention values more.

?np.power
Object `np.power` not found.
# Power-law contrast: raising values in [0, 1] to 1/gamma lifts small attention weights, more so for larger gamma
gamma = 5
enh_attn_mtx = attn_mtx.pow(1/gamma)
import matplotlib.pyplot as plt
fig,ax = plt.subplots(figsize=(5,10),dpi=150)
ax.imshow(enh_attn_mtx, vmin=enh_attn_mtx.min(), vmax=enh_attn_mtx.max(), interpolation='nearest');

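The sharp spike of attention at the beginning of the graph can’t come from the system prompt, since that was removed earlier; it most likely comes from the leading chat-template tokens right after the zeroed-out <|im_start|>.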

What I now want to do is visualize how the total attention on the image itself varied as each token was generated.

inp_tok_len = len(inps.input_ids[0]); inp_tok_len
410
tokz.convert_ids_to_tokens(inps.input_ids[0])

I’ll save the indices of the first and last vision tokens.

vis_tok_idxs = tokz.convert_ids_to_tokens(inps.input_ids[0]).index('<|vision_start|>'), tokz.convert_ids_to_tokens(inps.input_ids[0]).index('<|vision_end|>'); vis_tok_idxs
(3, 395)

I’ll now sum the attentions for each generated token, over all image tokens.

attn_mtx[inp_tok_len:].shape
torch.Size([27, 437])
attn_mtx[inp_tok_len:][0].shape
torch.Size([437])
attn_mtx[inp_tok_len:][0][vis_tok_idxs[0]:vis_tok_idxs[1]].shape
torch.Size([392])
# vis_tok_attn_w = [r[vis_tok_idxs[0]:vis_tok_idxs[1]].sum().item() for i,(r,t) in enumerate(zip(attn_mtx[inp_tok_len:],o.sequences[0].tolist()))]; vis_tok_attn_w
# Image tokens lie strictly between <|vision_start|> and <|vision_end|>, so start one past the start marker
vis_tok_attn_w = [r[vis_tok_idxs[0]+1:vis_tok_idxs[1]].sum().item() for r in attn_mtx[inp_tok_len:]]; vis_tok_attn_w
[0.1649223268032074,
 0.14924326539039612,
 0.24785716831684113,
 0.15989810228347778,
 0.18078750371932983,
 0.19855275750160217,
 0.2221490442752838,
 0.25423598289489746,
 0.33117902278900146,
 0.2534129023551941,
 0.23635193705558777,
 0.27973249554634094,
 0.2829405665397644,
 0.2884763479232788,
 0.20834580063819885,
 0.13748674094676971,
 0.19547797739505768,
 0.20929457247257233,
 0.2825366258621216,
 0.24532952904701233,
 0.19346649944782257,
 0.1775532066822052,
 0.16503523290157318,
 0.15991057455539703,
 0.19774067401885986,
 0.1700247824192047,
 0.10200894623994827]

Now I can visualize the sum of all attentions on all image tokens, as each new token was generated.

fig,ax = plt.subplots(figsize=(40,5), dpi=227)
ax.plot(vis_tok_attn_w)
ax.set_xticks(range(len(vis_tok_attn_w)))
ax.set_xticklabels([tokz.decode(t, add_special_tokens=False).strip() for t in trimmed_o[0].tolist()], rotation=75);

Next up, I want to visualize the attention weights of the vision encoder.

vlm.visual.blocks[0].attn
Qwen2_5_VLVisionAttention(
  (qkv): Linear(in_features=1280, out_features=3840, bias=True)
  (proj): Linear(in_features=1280, out_features=1280, bias=True)
)

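I expected the vision encoder to use SDPA (scaled dot-product attention), which runs fused C++/CUDA kernels and never exposes the attention weights; in that case I’d need a manual implementation to access them. As the next cell shows, the configured implementation is actually 'eager'. Even the eager path only computes the weights internally and does not return them.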

vlm.visual.config._attn_implementation
'eager'
vlm.visual.named_modules
<bound method Module.named_modules of Qwen2_5_VisionTransformerPretrainedModel(
  (patch_embed): Qwen2_5_VisionPatchEmbed(
    (proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
  )
  (rotary_pos_emb): Qwen2_5_VisionRotaryEmbedding()
  (blocks): ModuleList(
    (0-31): 32 x Qwen2_5_VLVisionBlock(
      (norm1): Qwen2RMSNorm((1280,), eps=1e-06)
      (norm2): Qwen2RMSNorm((1280,), eps=1e-06)
      (attn): Qwen2_5_VLVisionAttention(
        (qkv): Linear(in_features=1280, out_features=3840, bias=True)
        (proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (mlp): Qwen2_5_VLMLP(
        (gate_proj): Linear(in_features=1280, out_features=3420, bias=True)
        (up_proj): Linear(in_features=1280, out_features=3420, bias=True)
        (down_proj): Linear(in_features=3420, out_features=1280, bias=True)
        (act_fn): SiLU()
      )
    )
  )
  (merger): Qwen2_5_VLPatchMerger(
    (ln_q): Qwen2RMSNorm((1280,), eps=1e-06)
    (mlp): Sequential(
      (0): Linear(in_features=5120, out_features=5120, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=5120, out_features=2048, bias=True)
    )
  )
)>
with suplog():
  with t.no_grad(): x = vlm.visual(inps.pixel_values, inps.image_grid_thw, output_attentions=True)
Traceback (most recent call last):
  File "/tmp/ipykernel_265136/256214625.py", line 5, in suplog
    try: yield
  File "/tmp/ipykernel_265136/862200727.py", line 2, in <module>
    with t.no_grad(): x = vlm.visual(inps.pixel_values, inps.image_grid_thw, output_attentions=True)
  File "/home/data/Salman/miniforge3/envs/robobrain2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/data/Salman/miniforge3/envs/robobrain2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Qwen2_5_VisionTransformerPretrainedModel.forward() got an unexpected keyword argument 'output_attentions'

The output_attentions parameter doesn’t exist.

vlm.visual.named_modules
vlm.visual.blocks[0].attn??
Signature:       vlm.visual.blocks[0].attn(*args, **kwargs)

Type:            Qwen2_5_VLVisionAttention

String form:    

Qwen2_5_VLVisionAttention(

  (qkv): Linear(in_features=1280, out_features=3840, bias=True)

  (proj): Linear(in_features=1280, out_features=1280, bias=True)

)

File:            ~/miniforge3/envs/robobrain2/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Source:         

class Qwen2_5_VLVisionAttention(nn.Module):

    def __init__(self, dim: int, num_heads: int = 16) -> None:

        super().__init__()

        self.num_heads = num_heads

        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=True)

        self.proj = nn.Linear(dim, dim)



    def forward(

        self,

        hidden_states: torch.Tensor,

        cu_seqlens: torch.Tensor,

        rotary_pos_emb: Optional[torch.Tensor] = None,

        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,

    ) -> torch.Tensor:

        seq_length = hidden_states.shape[0]

        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if position_embeddings is None:

            logger.warning_once(

                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "

                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "

                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "

                "removed and `position_embeddings` will be mandatory."

            )

            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)

            cos = emb.cos()

            sin = emb.sin()

        else:

            cos, sin = position_embeddings

        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)



        attention_mask = torch.full(

            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype

        )

        for i in range(1, len(cu_seqlens)):

            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0



        q = q.transpose(0, 1)

        k = k.transpose(0, 1)

        v = v.transpose(0, 1)

        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)

        attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)

        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(0, 1)

        attn_output = attn_output.reshape(seq_length, -1)

        attn_output = self.proj(attn_output)

        return attn_output

Class docstring:

Base class for all neural network modules.



Your models should also subclass this class.



Modules can also contain other Modules, allowing to nest them in

a tree structure. You can assign the submodules as regular attributes::



    import torch.nn as nn

    import torch.nn.functional as F



    class Model(nn.Module):

        def __init__(self) -> None:

            super().__init__()

            self.conv1 = nn.Conv2d(1, 20, 5)

            self.conv2 = nn.Conv2d(20, 20, 5)



        def forward(self, x):

            x = F.relu(self.conv1(x))

            return F.relu(self.conv2(x))



Submodules assigned in this way will be registered, and will have their

parameters converted too when you call :meth:`to`, etc.



.. note::

    As per the example above, an ``__init__()`` call to the parent class

    must be made before assignment on the child.



:ivar training: Boolean represents whether this module is in training or

                evaluation mode.

:vartype training: bool

Init docstring:  Initialize internal Module state, shared by both nn.Module and ScriptModule.
vlm.visual.blocks[0].attn.*?
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.

Attaching a forward hook to this module will probably be the most straightforward method.

attn_ws = []
def get_attn_hook(mod, inp, out):
  """Forward hook: store `out` in the global `attn_ws` list and print a diagnostic summary.

  Prints the module's type and public attributes, the input tuple's type plus the shape and
  public attributes of its first element (when there is one), the output's type/shape/public
  attributes, and the module's children two levels deep. Returns None, so `out` is unchanged.
  """
  attn_ws.append(out)
  # PyTorch forward hooks receive only *positional* inputs, so `inp` may be empty when the module
  # is called with keyword arguments; check before indexing instead of assuming inp[0] exists.
  has_first = hasattr(inp, '__len__') and len(inp) > 0
  first = inp[0] if has_first else None
  print("="*60)
  print(f"MODULE: {type(mod).__name__}")
  print(f"Module attrs: {[attr for attr in dir(mod) if not attr.startswith('_')]}")
  print("-"*30)
  print(f"INPUT: {type(inp)} | Shape: {first.shape if hasattr(first, 'shape') else 'N/A'}")
  if has_first:
    print(f"Input[0] attrs: {[attr for attr in dir(first) if not attr.startswith('_')]}")
  print("-"*30)
  print(f"OUTPUT: {type(out)} | Shape: {out.shape if hasattr(out, 'shape') else 'N/A'}")
  print(f"Output attrs: {[attr for attr in dir(out) if not attr.startswith('_')]}")
  print("*"*30 + " CHILDREN " + "*"*30)
  for name, child in mod.named_children():
    print(f"  └── {name}: {type(child).__name__}")
    for subname, subchild in child.named_children():
      print(f"      └── {subname}: {type(subchild).__name__}")
  print("="*60)
for n,m in vlm.visual.named_modules():
  if n.endswith('.attn'):
    print(f'attaching hook to {n}')
    m.register_forward_hook(get_attn_hook)
for n,m in vlm.visual.named_modules():
  if m._forward_hooks: print(f'Forward hooks found on module: {n} -> {m._forward_hooks}')
with t.no_grad(): x = vlm.visual(inps.pixel_values, inps.image_grid_thw)

Each captured output is a 1564x1280 tensor. There are (644/14)·(476/14)=46·34=1564 patches. This doesn’t match the number of image tokens produced by the tokenizer because the hook captures the outputs of the attention modules inside the encoder blocks, before the patches are combined by the patch merger.

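The Qwen 2.5 VL patch merger combines patches into 2x2 blocks (its first linear layer takes 4·1280=5120 features). In this case, this results in (46/2)·(34/2)=23·17=391 merged tokens. That matches the 391 <|image_pad|> tokens between <|vision_start|> and <|vision_end|>; the 392-element slice taken earlier also includes <|vision_start|>.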

len(attn_ws), attn_ws[0].shape
(32, torch.Size([1564, 1280]))

Adding a hook doesn’t work as I only get the output feature vectors back, and not the intermediate attention weights. I need to monkey patch the forward method so it captures the weights.

vlm.visual??

It seems I’ll need to monkey patch something within the blocks themselves.

vlm.visual.blocks[0].attn??

I’ll need to monkey patch the forward method of the attention class.

vlm.visual.blocks[0].attn.forward??
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionAttention, apply_rotary_pos_emb_vision
import torch
from torch import nn
??Qwen2_5_VLVisionAttention
attn_ws = []
@patch
def forward(self: Qwen2_5_VLVisionAttention, hidden_states, cu_seqlens, rotary_pos_emb=None, position_embeddings=None):
  """Copy of the stock eager vision-attention forward that also records the attention weights.

  After the softmax, this layer's weights of shape (num_heads, seq_length, seq_length) are appended
  to the global `attn_ws` list, one entry per call. Query/key pairs that fall in different
  `cu_seqlens` segments get the dtype's minimum added before the softmax, so their weights are
  effectively zero.

  Args:
    hidden_states: per-patch features; dim 0 is taken as `seq_length`.
    cu_seqlens: cumulative segment boundaries; attention is only allowed within each segment.
    rotary_pos_emb: legacy RoPE angles, used only when `position_embeddings` is None.
    position_embeddings: (cos, sin) tuple for the rotary embedding.
  Returns:
    The projected attention output, shape (seq_length, dim).
  """
  # Local imports so the patch doesn't depend on star-imports or on transformers' module-level `logger`
  import math, warnings
  seq_length = hidden_states.shape[0]
  # (seq, 3, heads, head_dim) -> q, k, v each of shape (seq, heads, head_dim)
  q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
  if position_embeddings is None:
    # `logger` is never defined in this notebook's cells, so `logger.warning_once` would raise NameError here.
    # `warnings` reports each call site only once under the default filter.
    warnings.warn(
      "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
      "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
      "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
      "removed and `position_embeddings` will be mandatory.",
      FutureWarning,
    )
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    cos = emb.cos()
    sin = emb.sin()
  else:
    cos, sin = position_embeddings
  q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

  # Block-diagonal additive mask: 0 inside each cu_seqlens segment, dtype-min everywhere else
  attention_mask = torch.full(
    [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
  )
  for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

  # Heads first: (heads, seq, head_dim)
  q = q.transpose(0, 1)
  k = k.transpose(0, 1)
  v = v.transpose(0, 1)
  attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
  attn_weights = attn_weights + attention_mask
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
  # Capture this layer's (heads, seq, seq) weights.
  # NOTE: they stay on the model's device, so capturing every layer can use a lot of GPU memory.
  attn_ws.append(attn_weights)
  attn_output = torch.matmul(attn_weights, v)
  attn_output = attn_output.transpose(0, 1)
  attn_output = attn_output.reshape(seq_length, -1)
  attn_output = self.proj(attn_output)
  return attn_output

Confirming the patch has been applied.

# IPython `??`: check that the attention module's `forward` now shows the patched source
??vlm.visual.blocks[0].attn
vlm.visual.blocks[0].attn.forward??

Let’s give it a run. But first, remove the existing hooks.

# Drop every forward hook on the vision tower (the diagnostic hooks registered earlier).
# Note: this clears the private `_forward_hooks` dict, so hooks added by anything else go too.
for n,m in vlm.visual.named_modules():
  if getattr(m, '_forward_hooks', None): m._forward_hooks.clear()
# Fresh buffer for the patched attention forward to append into, then rerun the tower.
attn_ws = []
with t.no_grad(): x = vlm.visual(inps.pixel_values, inps.image_grid_thw)
len(attn_ws), attn_ws[0].shape
(32, torch.Size([16, 1564, 1564]))

It worked! The attention weights have a different size from the number of image tokens because the patches/feature vectors they correspond to haven’t yet passed through what’s known as the projection/pooling layer.

The RoboBrain model is based on Qwen 2.5 VL. The vision encoder in this model uses a patch size (patch-embedding kernel size) of 14.

# Side length in pixels of each ViT patch
vlm.visual.config.patch_size
14

The image at the beginning of the notebook has dimensions, after padding, 644x476.

# PIL reports (width, height)
img_inp[0].size
(644, 476)

Therefore, the number of patches produced is (644/14)·(476/14) = 46·34 = 1564. After the vision encoder produces a feature vector for each of these 1564 patches, the vectors are passed through the projection/pooling layer. In the Qwen 2.5 VL model, each 2×2 block of patches is merged, which results in (46/2)·(34/2) = 23·17 = 391 tokens. This matches the number of image tokens produced by the tokenizer.

I will now aggregate all the attention weights.

# NOTE(review): `l` is not defined in this cell and appears to be left over from an earlier one.
# Judging by the output shape (16, 410, 410), it held one LLM layer's attention with a leading
# batch axis, so this averages that axis away — TODO confirm.
attns_per_head = l.mean(dim=0); attns_per_head.shape
torch.Size([16, 410, 410])

In the reference notebook, a normalization step occurs after taking the mean across all heads. This is because in the LLaVA implementation the first token is a special token which doesn’t need to be visualized. Here I don’t need to do this since, as far as I’m aware, there is no special patch token.

If there were, I would have to perform the following.

# Illustration only (not needed for this model): when the first token is a special token
# (as in LLaVA), the reference notebook averages over heads, drops that token's row and
# column, and renormalizes each row. Average the head axis (dim 0) first; slicing
# [1:,1:] directly on (heads, q, k) would drop the first head and first query row instead.
vec = attns_per_head.mean(dim=0)[1:,1:].cpu()
vec/vec.sum(-1,keepdim=True)
# ViT attention averaged over heads (dim 0 of each captured tensor), then over all captured layers
vis_attn_mtx = t.stack([l.mean(dim=0).cpu() for l in attn_ws]).mean(dim=0); vis_attn_mtx.shape
torch.Size([1564, 1564])

Time to visualize. I’ll begin by visualizing the attention of only the first token.

# (patches x patches) ViT attention averaged over heads and layers
vis_attn_mtx.shape
torch.Size([1564, 1564])
# Length of the full first sequence: prompt tokens plus generated tokens (437 = 410 + 27 here)
out_tok_len = len(o.sequences[0]); out_tok_len
437
# Sequence positions of the generated tokens
out_tok_idxs = L(range(inp_tok_len, out_tok_len)); out_tok_idxs
(#27) [410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429...]
# NOTE: rebinds out_tok_len from the full sequence length to the number of generated tokens
out_tok_len = len(out_tok_idxs); out_tok_len
27

The target token is currently the first generated token.

# Sequence position of the first generated token
target_tok_idx = out_tok_idxs[0]; target_tok_idx
410
# LLM attention matrix over the full sequence (437 x 437 here)
attn_mtx.shape
torch.Size([437, 437])
# Row for the target token (presumably its attention over every sequence position)
attn_mtx[target_tok_idx].shape
torch.Size([437])
# NOTE(review): this slice has 392 entries, but 1564 patches merged 2x2 give only 391 image tokens.
# The slice therefore likely includes one non-image token (e.g. <|vision_start|> at the start
# index) — TODO confirm against the prompt's input_ids.
vis_tok_idxs, attn_mtx[target_tok_idx][vis_tok_idxs[0]:vis_tok_idxs[1]].shape
((3, 395), torch.Size([392]))

I’ll obtain the attention weights over the vision tokens, and then normalize those weights.

# Target token's LLM attention over the sliced vision positions, and its range
# NOTE(review): 392 positions vs 391 image tokens — the slice may include one non-image token.
attn_ws_over_vis_toks = attn_mtx[target_tok_idx][vis_tok_idxs[0]:vis_tok_idxs[1]]
attn_ws_over_vis_toks.min(), attn_ws_over_vis_toks.max()
(tensor(3.0008e-05), tensor(0.0208))
# Renormalize so the weights over these positions sum to 1
attn_ws_over_vis_toks = attn_ws_over_vis_toks / attn_ws_over_vis_toks.sum()
attn_ws_over_vis_toks.min(), attn_ws_over_vis_toks.max()
(tensor(0.0002), tensor(0.1263))
# Token-level LLM weights vs. the patch-level ViT matrix
attn_ws_over_vis_toks.shape, vis_attn_mtx.shape
(torch.Size([392]), torch.Size([1564, 1564]))
# Pair each LLM weight with one ViT attention row. zip stops at the shorter input, so only the
# first 392 of the 1564 patch rows are used, one row per sliced position.
# NOTE(review): each image token covers a 2x2 block of patches, so a one-row-per-token pairing
# does not line tokens up with their patches — TODO verify the intended mapping.
y = L(zip(attn_ws_over_vis_toks, vis_attn_mtx)); y
(#392) [(tensor(0.0114), tensor([7.2815e-02, 2.7817e-02, 2.7298e-02,  ..., 3.5703e-05, 9.5844e-05,
        1.3089e-04], dtype=torch.float16)),(tensor(0.1002), tensor([5.5511e-02, 3.5126e-02, 2.1759e-02,  ..., 6.6400e-05, 6.7353e-05,
        8.6725e-05], dtype=torch.float16)),(tensor(0.0592), tensor([5.8105e-02, 1.9745e-02, 4.0253e-02,  ..., 4.3392e-05, 5.0604e-05,
        7.5161e-05], dtype=torch.float16)),(tensor(0.1263), tensor([4.4434e-02, 2.7466e-02, 2.9800e-02,  ..., 4.5955e-05, 4.0114e-05,
        5.9068e-05], dtype=torch.float16)),(tensor(0.0875), tensor([4.9561e-02, 2.5208e-02, 1.5625e-02,  ..., 5.1498e-05, 6.5625e-05,
        1.1742e-04], dtype=torch.float16)),(tensor(0.0013), tensor([4.0619e-02, 2.0660e-02, 1.8677e-02,  ..., 1.0329e-04, 8.6665e-05,
        1.9145e-04], dtype=torch.float16)),(tensor(0.0005), tensor([3.3173e-02, 1.9882e-02, 1.4587e-02,  ..., 5.0008e-05, 7.7903e-05,
        9.0420e-05], dtype=torch.float16)),(tensor(0.0036), tensor([1.8967e-02, 1.4717e-02, 1.7731e-02,  ..., 7.9393e-05, 4.9829e-05,
        4.2319e-05], dtype=torch.float16)),(tensor(0.0086), tensor([0.0320, 0.0208, 0.0108,  ..., 0.0001, 0.0001, 0.0002],
       dtype=torch.float16)),(tensor(0.0018), tensor([2.7313e-02, 1.5854e-02, 1.1726e-02,  ..., 1.1587e-04, 1.2165e-04,
        9.3877e-05], dtype=torch.float16)),(tensor(0.0009), tensor([1.7624e-02, 1.1581e-02, 1.1177e-02,  ..., 6.8367e-05, 9.2864e-05,
        1.7357e-04], dtype=torch.float16)),(tensor(0.0008), tensor([1.8936e-02, 1.0399e-02, 1.0033e-02,  ..., 3.6418e-05, 5.7578e-05,
        4.4882e-05], dtype=torch.float16)),(tensor(0.0015), tensor([2.3804e-02, 1.3443e-02, 1.0323e-02,  ..., 3.4571e-05, 4.4703e-05,
        1.3959e-04], dtype=torch.float16)),(tensor(0.0008), tensor([3.4149e-02, 1.6403e-02, 1.2215e-02,  ..., 2.6405e-05, 4.3452e-05,
        5.7697e-05], dtype=torch.float16)),(tensor(0.0015), tensor([2.0523e-02, 1.3504e-02, 1.1162e-02,  ..., 3.9041e-05, 4.8935e-05,
        7.4387e-05], dtype=torch.float16)),(tensor(0.0007), tensor([1.9135e-02, 1.0620e-02, 1.3245e-02,  ..., 2.6762e-05, 5.5313e-05,
        3.9697e-05], dtype=torch.float16)),(tensor(0.0004), tensor([3.3081e-02, 1.1414e-02, 3.4821e-02,  ..., 2.5630e-05, 5.0366e-05,
        7.6592e-05], dtype=torch.float16)),(tensor(0.0005), tensor([2.1408e-02, 1.3779e-02, 2.8259e-02,  ..., 1.3947e-05, 3.1412e-05,
        3.7611e-05], dtype=torch.float16)),(tensor(0.0007), tensor([2.5253e-02, 1.3268e-02, 1.7975e-02,  ..., 8.7857e-05, 1.4389e-04,
        2.6059e-04], dtype=torch.float16)),(tensor(0.0004), tensor([1.5945e-02, 7.9346e-03, 1.0849e-02,  ..., 1.1802e-05, 4.8161e-05,
        3.2246e-05], dtype=torch.float16))...]
# First pair: (LLM weight, ViT attention row over all 1564 patches)
y[0], y[0][1].shape
((tensor(0.0114),
  tensor([7.2815e-02, 2.7817e-02, 2.7298e-02,  ..., 3.5703e-05, 9.5844e-05,
          1.3089e-04], dtype=torch.float16)),
 torch.Size([1564]))

So what I have here, y, is the LLM attention weights over the image tokens, and the ViT attention weights of the image patches.

# Patch grid as (patches across the width, patches down the height): PIL size is (w, h)
grid_sz = img_inp[0].size[0]//14, img_inp[0].size[1]//14; grid_sz
(46, 34)

I need to reshape each ViT/vision encoder attention vector into the 2D patch grid. I’ll then multiply each reshaped attention map by the corresponding LLM attention weight and sum the results to get the final attention over the image.

# NOTE(review): reshaping to (46, 34) puts the image width on the first axis; a row-major patch
# ordering would need (34, 46) — TODO confirm how the processor orders patches.
y[0][1].reshape(grid_sz[0],grid_sz[1]).shape
torch.Size([46, 34])
# Weight each reshaped ViT attention row by its LLM weight (one map per pair in y)
# NOTE(review): the pairing and the (46, 34) orientation are unverified; Qwen2.5-VL's vision tower
# may also reorder patches into attention windows before its blocks — TODO verify.
attn_over_img = []
for w, vis_attn in y:
  vis_attn = vis_attn.reshape(grid_sz[0],grid_sz[1])
  attn_over_img.append(vis_attn*w)
len(attn_over_img), attn_over_img[0].shape
(392, torch.Size([46, 34]))
# Sum the weighted maps into a single heatmap
attn_over_img = t.stack(attn_over_img).sum(dim=0); attn_over_img.shape
torch.Size([46, 34])
# Scale so the hottest cell is 1
attn_over_img = attn_over_img / attn_over_img.max()
import torch.nn.functional as F
??F.interpolate

The reason why I’m interpolating here is to align the resolution of the current attention heatmap to the original input image. In essence, I’m upscaling the attention heatmap.

# Add batch and channel axes: F.interpolate expects (N, C, H, W) for 2-D resizing
attn_over_img[None,None,:,:].shape
torch.Size([1, 1, 46, 34])
# img_inp[0].size is (width, height) and is used as the output (dim0, dim1), so the result keeps
# the width on the first axis, matching the (46, 34) grid above
attn_over_img = F.interpolate(attn_over_img[None,None,:,:], size=img_inp[0].size, mode='nearest'); attn_over_img.shape
torch.Size([1, 1, 644, 476])

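`mode='nearest'` copies each grid cell’s value into every output pixel it covers. The upscaled heatmap therefore keeps hard cell edges instead of being smoothed out, as it would be with bilinear or bicubic interpolation. This keeps the resulting image from looking too flat or smooth.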

import numpy as np
# np.array gives (height, width, 3); transpose to (width, height, 3) to match the heatmap's orientation
np_img = np.array(img_inp[0]).transpose(1,0,2); np_img.shape
(644, 476, 3)
# Reverse the channel order (RGB -> BGR) to match OpenCV's colour conventions
np_img = np_img[:,:,::-1]; np_img.shape
(644, 476, 3)

I’ll now overlay the attention mask on the image.

np_img = np.float32(np_img)/255
import cv2
# Kept inside suplog on purpose so the traceback (typo `cv2.COLOMAP_HSV`) stays in the notebook
with suplog(): heatmap = cv2.applyColorMap(np.uint8(255*attn_over_img.numpy()), cv2.COLOMAP_HSV)
# Drop the batch/channel axes to get a 2-D map; OpenCV colour maps come back in BGR
heatmap = cv2.applyColorMap(np.uint8(255*attn_over_img[0,0,...].to(t.float16).numpy()), cv2.COLORMAP_HSV); heatmap
heatmap = np.float32(heatmap)/255
with suplog(): cam = heatmap + np.float32(np_img)
# Additive overlay, rescaled so the brightest value is 1
cam = heatmap + np.float32(np_img)
cam /= np.max(cam)
# NOTE(review): np.uint produced uint64 here (see the dtype output); an 8-bit image needs np.uint8
img_with_attn = np.uint(255*cam)
img_with_attn.shape, type(img_with_attn), img_with_attn.dtype
((644, 476, 3), numpy.ndarray, dtype('uint64'))
from PIL import Image as PILImage
# NOTE(review): presumably fails because PIL cannot build an image from a uint64 array, hence suplog
with suplog(): PILImage.fromarray(img_with_attn)

Now I’ll visualize the map for all tokens.

# Panels per row of the per-token figure
num_imgs_per_row = 8
# Image aspect ratio as width / height (PIL size is (w, h))
img_ratio = img_inp[0].size[0] / img_inp[0].size[1]; img_ratio
1.3529411764705883
# Number of generated tokens = number of panels to draw
out_tok_len
27
# Rows needed to hold out_tok_len panels: ceiling division
num_rows = (out_tok_len + num_imgs_per_row - 1)//num_imgs_per_row; num_rows
4
# Preview the panel grid. Each panel is 10/num_imgs_per_row inches wide; since img_ratio is
# width/height, a panel's height is that width divided (not multiplied) by img_ratio.
fig, axes = plt.subplots(num_rows, num_imgs_per_row, figsize=(10,(10/num_imgs_per_row)/img_ratio*num_rows), dpi=150)
plt.subplots_adjust(wspace=.05, hspace=.2)

# Confirm the input is a PIL image, so .size is (width, height)
type(img_inp[0])
PIL.Image.Image
# One heatmap per generated token: that token's LLM attention over the image tokens, spread
# back onto the patch grid through the ViT attention (averaged over heads and layers).
#
# Index bookkeeping for Qwen 2.5 VL (transformers implementation):
# * The processor flattens patches in (t, h/m, w/m, m, m) order: every run of m*m consecutive
#   patches is one m x m merge unit, and merge units are row-major.
# * The vision tower reorders merge units into attention windows (`get_window_index`) before
#   its blocks, so the rows/cols of `vis_attn_mtx` follow that window order.
# * After merging, the tower restores the original order, so LLM image token k is merge unit k,
#   i.e. processor patches k*m*m .. (k+1)*m*m-1.
_, grid_h, grid_w = inps.image_grid_thw[0].tolist()
merge = vlm.visual.spatial_merge_size
unit = merge*merge

win_idx, _ = vlm.visual.get_window_index(inps.image_grid_thw)
# position j of the captured attention matrices holds processor patch perm[j]
perm = (win_idx.cpu()[:,None]*unit + t.arange(unit)).flatten()
inv = t.argsort(perm)
vit_attn = vis_attn_mtx.float()[inv][:,inv]  # rows/cols now in processor order

# proc2raster[p] = row-major index on the (grid_h, grid_w) grid of processor patch p
proc2raster = t.arange(grid_h*grid_w).reshape(grid_h//merge, merge, grid_w//merge, merge).permute(0,2,1,3).flatten()
# One spatial ViT map per image token: average the rows of its m*m patches, put the columns in
# raster order, reshape to (rows, cols), and scale each map so its maximum is 1.
tok_maps = vit_attn.reshape(-1, unit, vit_attn.shape[-1]).mean(dim=1)[:, t.argsort(proc2raster)]
tok_maps = tok_maps.reshape(-1, grid_h, grid_w)
tok_maps = tok_maps / tok_maps.amax(dim=(1,2), keepdim=True)

# Sequence positions of the <|image_pad|> tokens. Slicing between vis_tok_idxs yields 392
# positions, one more than the 391 merge units, so select the image tokens explicitly.
img_tok_pos = (inps.input_ids[0].cpu() == tokz.convert_tokens_to_ids('<|image_pad|>')).nonzero().squeeze(-1)
assert len(img_tok_pos) == len(tok_maps), (len(img_tok_pos), len(tok_maps))

img_w, img_h = img_inp[0].size  # PIL size is (width, height)
np_img = np.float32(np.array(img_inp[0]))/255  # (height, width, 3) RGB, as matplotlib expects

# Panel height = panel width / img_ratio, because img_ratio is width/height.
fig, axes = plt.subplots(num_rows, num_imgs_per_row, figsize=(10,(10/num_imgs_per_row)/img_ratio*num_rows), dpi=150)
plt.subplots_adjust(wspace=.05, hspace=.2)

for i,ax in enumerate(axes.flatten()):
  ax.axis('off')
  if i>=out_tok_len: continue

  # LLM attention of generated token i over the image tokens, renormalized to sum to 1
  w = attn_mtx[out_tok_idxs[i]][img_tok_pos].float().cpu()
  w = w / w.sum()
  attn_over_img = (w[:,None,None]*tok_maps).sum(dim=0)
  attn_over_img = attn_over_img / attn_over_img.max()

  attn_over_img = F.interpolate(attn_over_img[None,None,...], size=(img_h,img_w), mode='bicubic', align_corners=False).squeeze()
  # bicubic overshoots [0, 1]: clamp so pow() cannot produce NaNs and uint8 cannot wrap around
  attn_over_img = attn_over_img.clamp(0,1).pow(1/1.5)

  hm = cv2.applyColorMap(np.uint8(255*attn_over_img.numpy()), cv2.COLORMAP_HSV)
  hm = np.float32(cv2.cvtColor(hm, cv2.COLOR_BGR2RGB))/255  # OpenCV colour maps are BGR
  cam = hm + np_img
  cam /= np.max(cam)

  ax.imshow(np.uint8(255*cam))
  ax.set_title(tokz.decode(trimmed_o[0][i], add_special_tokens=False).strip(), fontsize=7, pad=1)
Back to top