
πŸ”Ž VL-SAM Hands-on: Let's Try Out VL-SAM!


Overview

이 ν¬μŠ€νŠΈμ—μ„œλŠ” VL-SAM의 핡심인 β€œκ°μ²΄ 이름 β†’ μœ„μΉ˜ 힌트(Attention Map) 생성 β†’ SAM 포인트 ν”„λ‘¬ν”„νŠΈβ€ 쀑, Attention Map 생성(VLM μΈ‘) μ‹€μŠ΅μ„ λ‹€λ£Ήλ‹ˆλ‹€.
(반볡 iteration 없이, 단일 패슀 μ„±κ²©μ˜ 데λͺ¨)


1) [Object Recognition] – Attention Map Generation (VLM)

Key idea: build the object prompt that goes into SAM from the VLM's attention flow!

  • a. 이미지 기반으둜 VLMμ—κ²Œ β€œμ΄λ―Έμ§€ 속 λͺ¨λ“  객체λ₯Ό λ‚˜μ—΄ν•΄β€λΌκ³  μš”μ²­ β†’ μ‘λ‹΅μ—μ„œ Tag2Text μœ μ‚¬ λ°©μ‹μœΌλ‘œ 객체 리슀트λ₯Ό 확보
  • b. 토큰 생성 κ³Όμ •μ—μ„œ λͺ¨λ“  λ ˆμ΄μ–΄/ν—€λ“œμ˜ Q, Kλ₯Ό μ €μž₯
  • c. μΆ”μΆœν•œ 객체 토큰에 λŒ€ν•΄ Q Γ— Kα΅€ β†’ causal mask 적용 β†’ Softmax ν‘œμ€€ν™” β†’ similarity matrix S
  • d. λ ˆμ΄μ–΄/ν—€λ“œ 별 κ°€μ€‘μΉ˜ W μ‚°μΆœ
    • μ˜ˆμ‹œ: W = Mean(Max(S, dim=1), dim=0)
  • e. κ°€μ€‘μΉ˜ W둜 λ³΄μ •λœ Sβ€² 계산
  • f. λ ˆμ΄μ–΄λ³„ Sβ€²λ₯Ό μ’…ν•©ν•˜μ—¬ attention flow μ‚°μΆœ
  • g. Auto-Regressive VLM의 νŠΉμ„±(μ’Œμƒλ‹¨μœΌλ‘œ collapse κ²½ν–₯)을 쀄이기 μœ„ν•΄ Regularized attention flow column μ‚¬μš©
  • Finally, μ΅œμ’… Attention Map μ™„μ„±!

2) Hands-on Code (Attention Map Generation)

Assumed environment

  • Qwen/Qwen2.5-VL-3B-Instruct
  • 4-bit μ–‘μžν™”(bitsandbytes) μ‚¬μš© κ°€λŠ₯ μ‹œ λ©”λͺ¨λ¦¬ μ ˆμ•½
  • μ•„λž˜ μ½”λ“œλŠ” μ–΄ν…μ…˜ λ§΅ μƒμ„±κΉŒμ§€ (SAM 연동은 별도 λ‹¨κ³„μ—μ„œ μ§„ν–‰ κ°€λŠ₯)

Notes

  • λΈ”λ‘œκ·Έμ— 맞좰 μ½”λ“œ νŽœμŠ€λŠ” $$$ 둜 κ°μŒŒμŠ΅λ‹ˆλ‹€. 볡뢙 ν›„ λ°”λ‘œ μ‹€ν–‰ν•˜μ„Έμš”.
  • 경둜(IMAGE_PATH)λŠ” 둜컬 ν™˜κ²½μ— 맞게 μˆ˜μ •ν•˜μ„Έμš”.

$$$
import os
import re

import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers import BitsAndBytesConfig

# ------------------------------
# Config
# ------------------------------
QWEN_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"                    # adjust to your VRAM (3B recommended)
IMAGE_PATH = "/home/bongo/porter_notebook/research/dog.jpg"   # demo image (a single dog)
TEXT_QUERY = "dog"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit quantization (saves memory when available) -- delete this line if bitsandbytes is not installed
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

print(f"[Info] Device: {DEVICE}")
image_pil = Image.open(IMAGE_PATH).convert("RGB")

print(f"[Info] Loading {QWEN_MODEL}...")
processor = AutoProcessor.from_pretrained(QWEN_MODEL, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    QWEN_MODEL,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=bnb,   # remove if bitsandbytes is not installed
    trust_remote_code=True,
)
print("[Info] Model loaded.")

# ------------------------------
# Step (a): generate an object list from the image (Tag2Text-style)
# ------------------------------
print("\n[Step a] Generating object list from the image...")

messages_for_obj_list = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Please analyze the image and list all the objects present."},
        ],
    }
]

text_template = processor.apply_chat_template(
    messages_for_obj_list, tokenize=False, add_generation_prompt=True
)
inputs_for_obj_list = processor(
    text=[text_template], images=[image_pil], return_tensors="pt"
).to(DEVICE)

with torch.no_grad():
    output = model.generate(
        **inputs_for_obj_list,
        max_new_tokens=256,   # longer outputs cost more memory
        do_sample=False,      # greedy decoding
    )

generated_text = processor.batch_decode(output, skip_special_tokens=True)[0]

# Simple parser: take everything after "assistant\n" and split on commas/newlines
obj_list_raw = generated_text.split("assistant\n")[-1].strip()
object_list = [obj.strip().lower() for obj in re.split(r"[,\n]", obj_list_raw) if obj.strip()]

print(f"  - Detected objects (raw): {object_list}")

# For the demo, use only a single target object
if len(object_list) == 0:
    object_list = [TEXT_QUERY]
target_object = object_list[0]
print(f"  - Target object: {target_object}")

# ------------------------------
# Final: collect decoder attentions via a manual generation loop -> approximate the map
# (Note) Vision-patch attentions carry clearer spatial meaning, but under
# build/memory constraints we approximate with decoder attentions.
# ------------------------------
print(f"\n[Final Step] Generating attention map for '{target_object}' via a MANUAL generation loop...")

prompt_for_attention = "Please analyze the image and list all the objects present."
messages_for_attention = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_for_attention}]}
]
text_template = processor.apply_chat_template(
    messages_for_attention, tokenize=False, add_generation_prompt=True
)
inputs = processor(text=[text_template], images=[image_pil], return_tensors="pt").to(DEVICE)
input_ids = inputs.input_ids

# Target token ID (simply encode " <object>" with a leading space)
token_ids = processor.tokenizer.encode(f" {target_object}", add_special_tokens=False)
target_token_id = token_ids[0] if len(token_ids) > 0 else None

found_target_attention = None
past_key_values = None
cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
step_inputs = dict(inputs)   # first step: full inputs incl. pixel_values / image_grid_thw

# Limit generation length for memory safety (e.g., at most 30 tokens)
for step in range(30):
    with torch.no_grad():
        outputs = model(
            **step_inputs,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=True,
            output_attentions=True,   # request decoder attentions
            return_dict=True,
        )

    next_token_logits = outputs.logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

    # Save the attentions of the step in which the target token is generated
    if target_token_id is not None and next_token.item() == target_token_id and found_target_attention is None:
        print(f"  - Found '{target_object}' token at generation step {step}.")
        # outputs.attentions: tuple of num_layers x (B, num_heads, tgt_len, src_len)
        found_target_attention = outputs.attentions

    # Prepare the next step: append the new token, reuse the KV cache,
    # and from now on feed only the newly generated token
    input_ids = torch.cat([input_ids, next_token], dim=-1)
    past_key_values = outputs.past_key_values
    cache_position = cache_position[-1:] + 1
    step_inputs = {"input_ids": next_token}

    # Stop at EOS
    if next_token.item() == processor.tokenizer.eos_token_id:
        print("  - Reached EOS.")
        break

print("  - Manual generation complete.")

# ------------------------------
# Aggregate the attention map (approximation) and visualize
# ------------------------------
if found_target_attention is not None:
    print(f"  - Aggregating attention from {len(found_target_attention)} layers...")
    # Rough estimate of the vision patch grid (good enough for this demo)
    image_size = 448
    patch_size = getattr(model.config.vision_config, "patch_size", 14)
    grid_size = max(1, image_size // patch_size)
    num_patches = grid_size * grid_size

    all_layer_attentions = []
    for layer_attention in found_target_attention:
        # layer_attention: (B, num_heads, tgt_len, src_len)
        # Take the src distribution seen by the last query (the token that predicted the target)
        token_row = layer_attention[0, :, -1, :]             # (num_heads, src_len)
        # Assume the leading positions correspond to image patch tokens; keep num_patches of them (approximation)
        image_patch_attention = token_row[:, :num_patches]   # (num_heads, num_patches)
        layer_avg = image_patch_attention.mean(dim=0)        # (num_patches,)
        all_layer_attentions.append(layer_avg)

    final_avg = torch.stack(all_layer_attentions, dim=0).mean(dim=0)       # (num_patches,)
    attention_map = final_avg.float().reshape(grid_size, grid_size).cpu().numpy()

    # Upsample to the original resolution
    resized_map = cv2.resize(attention_map, (image_pil.width, image_pil.height), interpolation=cv2.INTER_CUBIC)

    # Visualize
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1); plt.title("Original Image"); plt.axis("off"); plt.imshow(image_pil)
    plt.subplot(1, 2, 2); plt.title(f"Aggregated Attention for '{target_object}'"); plt.axis("off")
    plt.imshow(image_pil); plt.imshow(resized_map, cmap="jet", alpha=0.5)
    plt.tight_layout()
    plt.savefig("attention_map_result.png", dpi=200)
    print("[OK] Saved attention_map_result.png")
else:
    print(f"  - Could not find or generate the token for '{target_object}' within the generation limit.")
$$$

3) Run Tips & Troubleshooting

  • λ©”λͺ¨λ¦¬(OOM) νšŒν”Ό
    • κ°€λŠ₯ν•˜λ©΄ 3B 체크포인트 μ‚¬μš©
    • μž…λ ₯ 이미지λ₯Ό λ¨Όμ € κΈ΄ λ³€ 384~448둜 μΆ•μ†Œ ν›„ νˆ¬μž…
    • max_new_tokensλ₯Ό 48~128 λ²”μœ„λ‘œ μ œν•œ
  • bitsandbytes(4bit) λ―Έμ„€μΉ˜/ν˜Έν™˜ 문제
    • quantization_config=bnb 쀄을 μ‚­μ œν•˜κ³  μ‹€ν–‰ (λŒ€μ‹  VRAM μ—¬μœ  ν•„μš”)
    • λ˜λŠ” pip install -U bitsandbytes (NVIDIA CUDA ν™˜κ²½μ—μ„œλ§Œ)
  • μ–΄ν…μ…˜μ΄ None으둜 μ˜€λŠ” 경우
    • λΉŒλ“œμ— 따라 generate()μ—μ„œλŠ” μ–΄ν…μ…˜μ„ λ°˜ν™˜ν•˜μ§€ μ•ŠμŒ
    • μœ„ μ½”λ“œλŠ” μˆ˜λ™ 생성 루프 + output_attentions=True둜 디코더 μ–΄ν…μ…˜μ„ μˆ˜μ§‘ (곡간적 근사)
    • 더 μ •ν™•ν•œ 곡간 히트맡이 ν•„μš”ν•˜λ©΄ λΉ„μ „ νƒ€μ›Œ μ–΄ν…μ…˜ ν›„ν‚Ή λ˜λŠ” νžˆλ“  μœ μ‚¬λ„(cosine) λ°©μ‹μœΌλ‘œ λŒ€μ²΄ κ°€λŠ₯

4) λ‹€μŒ 단계 (SAM 연동)

  • μœ„μ—μ„œ 얻은 Attention Mapμ—μ„œ Positive/Negative 포인트λ₯Ό μƒ˜ν”Œλ§
  • segment-anything의 SamPredictor.predict(point_coords, point_labels)에 전달
  • κ²°κ³Ό segmentation mask μ‹œκ°ν™”(overlay)

The full SAM integration example is covered in a separate post/section (a checkpoint such as sam_vit_b.pth is required); a minimal preview of the point-sampling step is sketched below.
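
The preview below assumes resized_map, image_pil, and DEVICE from the script above and a local sam_vit_b.pth checkpoint path (adjust to your environment). Sampling the single attention maximum/minimum is the simplest possible strategy, used here only for illustration.

$$$
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth").to(DEVICE)
predictor = SamPredictor(sam)
predictor.set_image(np.array(image_pil))   # HxWx3 uint8 RGB

# Positive point at the attention maximum, negative point at the minimum
pos_y, pos_x = np.unravel_index(np.argmax(resized_map), resized_map.shape)
neg_y, neg_x = np.unravel_index(np.argmin(resized_map), resized_map.shape)
point_coords = np.array([[pos_x, pos_y], [neg_x, neg_y]], dtype=np.float32)  # (x, y) order
point_labels = np.array([1, 0])                                              # 1 = positive, 0 = negative

masks, scores, _ = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)
best_mask = masks[np.argmax(scores)]   # highest-scoring mask for the overlay
$$$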


Wrapping Up

이 κΈ€μ—μ„œλŠ” VLM 기반 Attention Map을 λ§Œλ“€μ–΄ VL-SAM의 μ „λ°˜μ μΈ 흐름을 κ²€μ¦ν–ˆμŠ΅λ‹ˆλ‹€.
ν™˜κ²½/λΉŒλ“œ 차이둜 μ–΄ν…μ…˜ 제곡 방식이 λ‹€λ₯Ό 수 μžˆμœΌλ‹ˆ, μœ„ μˆ˜λ™ 루프/근사 μ „λž΅μ„ 기반으둜 μ‹œμž‘ν•΄ λ³΄μ„Έμš”.
ν•„μš”ν•˜μ‹œλ©΄ λΉ„μ „ μ–΄ν…μ…˜ 후크 버전 λ˜λŠ” νžˆλ“  μœ μ‚¬λ„ 기반 κ²½λŸ‰ 버전도 첨뢀해 λ“œλ¦΄κ²Œμš”.

This post is licensed under CC BY 4.0 by the author.