VL-SAM Hands-on: Let's Try VL-SAM Ourselves!

Overview

This post covers the attention map generation step (the VLM side) of VL-SAM's core pipeline: object name → location hint (attention map) generation → SAM point prompt.
(A single-pass style demo, without the iterative refinement loop.)
1) [Object Recognition] → Attention Map Generation (VLM)

Core idea: build the object prompt that goes into SAM from the VLM's attention flow!

- a. Given the image, ask the VLM to "list every object in the image" → extract an object list from the response, in a Tag2Text-like fashion
- b. During token generation, store Q and K from every layer/head
- c. For the extracted object token, compute Q × Kᵀ → apply the causal mask → normalize with Softmax → similarity matrix S
- d. Compute a per-layer/head weight W
  - Formula: W = Mean(Max(S, dim=1), dim=0)
- e. Compute the weighted S′ using W
- f. Aggregate the per-layer S′ into the attention flow
- g. To reduce the auto-regressive VLM's tendency to collapse onto the top-left token, use a regularized attention flow column
- Finally, the attention map is produced! (A minimal sketch of steps (c) through (g) follows this list.)
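Before the full script, here is a minimal sketch of steps (c) through (g). It assumes the per-layer Q/K tensors (shape `(num_heads, seq_len, head_dim)`) were already captured during generation; the function names and the rollout-style layer aggregation are illustrative choices, not the paper's reference implementation.

$$$
import torch
import torch.nn.functional as F

def layer_similarity(q, k):
    """Step (c): S = Softmax(causal_mask(Q @ K^T / sqrt(d)))."""
    num_heads, seq_len, head_dim = q.shape
    scores = (q @ k.transpose(-1, -2)) / head_dim ** 0.5              # (H, L, L)
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))                # causal mask
    return F.softmax(scores, dim=-1)                                  # similarity matrix S

def reweight_heads(S):
    """Steps (d)-(e): W = Mean(Max(S, dim=1), dim=0), then S' = W * S."""
    W = S.max(dim=2).values.mean(dim=1)                               # one scalar per head
    return W[:, None, None] * S

def attention_flow(per_layer_qk):
    """Steps (f)-(g): aggregate S' across layers into an attention flow."""
    flow = None
    for q, k in per_layer_qk:                                         # shallow -> deep
        S_prime = reweight_heads(layer_similarity(q, k)).mean(dim=0)  # head-averaged (L, L)
        flow = S_prime if flow is None else S_prime @ flow            # rollout-style propagation
    # Step (g): zero out the first column, which auto-regressive decoders
    # tend to collapse onto (the top-left token), then renormalize
    flow[:, 0] = 0.0
    return flow / (flow.sum(dim=-1, keepdim=True) + 1e-8)
$$$

With `per_layer_qk` as a list of `(q, k)` pairs collected in step (b), the row `attention_flow(per_layer_qk)[target_idx, :num_patches]` is what gets reshaped into the 2-D attention map.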
2) Practice Code (Attention Map Generation)

Environment assumptions
- Qwen/Qwen2.5-VL-3B-Instruct
- 4-bit quantization (bitsandbytes) can be used when memory is tight
- The code below goes up to attention map generation (the SAM hand-off can be done as a separate step)

Notes
- To fit the blog, code fences were swapped for $$$. Paste and run as-is.
- Adjust the path (IMAGE_PATH) to your local environment.
$$$
import os
import re
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers import BitsAndBytesConfig

# ========= Config =========
QWEN_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"  # adjust to your VRAM (3B recommended)
IMAGE_PATH = "/home/bongo/porter_notebook/research/dog.jpg"  # demo image (a single dog)
TEXT_QUERY = "dog"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit quantization (saves memory when available); remove this line if bitsandbytes is not installed
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

print(f"[Info] Device: {DEVICE}")
image_pil = Image.open(IMAGE_PATH).convert("RGB")

print(f"[Info] Loading {QWEN_MODEL}...")
processor = AutoProcessor.from_pretrained(QWEN_MODEL, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    QWEN_MODEL,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=bnb,  # remove if bitsandbytes is not installed
    trust_remote_code=True,
)
print("[Info] Model loaded.")
# ==============================
# Step (a): generate the object list from the image (Tag2Text-like)
# ==============================
print("\n[Step a] Generating object list from the image...")

messages_for_obj_list = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Please analyze the image and list all the objects present."},
        ],
    }
]

text_template = processor.apply_chat_template(messages_for_obj_list, tokenize=False, add_generation_prompt=True)
inputs_for_obj_list = processor(text=[text_template], images=[image_pil], return_tensors="pt").to(DEVICE)

with torch.no_grad():
    output = model.generate(
        **inputs_for_obj_list,
        max_new_tokens=256,  # longer outputs cost more memory
        do_sample=False,
    )

generated_text = processor.batch_decode(output, skip_special_tokens=True)[0]

# Simple parsing: take everything after "assistant\n" and split on commas/newlines
obj_list_raw = generated_text.split("assistant\n")[-1].strip()
object_list = [obj.strip().lower() for obj in re.split(r"[,\n]", obj_list_raw) if obj.strip()]
print(f" - Detected objects (raw): {object_list}")

# For the demo, use only the first target object
if len(object_list) == 0:
    object_list = [TEXT_QUERY]
target_object = object_list[0]
print(f" - Target object: {target_object}")
# ====================================
# Final: collect decoder attention via a manual generation loop -> approximate map
# (Note) Vision-patch attention carries clearer spatial meaning, but under
# build/memory constraints we use a decoder-attention approximation here.
# ====================================
print(f"\n[Final Step] Generating attention map for '{target_object}' via a MANUAL generation loop...")

prompt_for_attention = "Please analyze the image and list all the objects present."
messages_for_attention = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_for_attention}]}
]
text_template = processor.apply_chat_template(messages_for_attention, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text_template], images=[image_pil], return_tensors="pt").to(DEVICE)
input_ids = inputs.input_ids

# Target token ID (simple leading-space encoding)
token_ids = processor.tokenizer.encode(f" {target_object}", add_special_tokens=False)
target_token_id = token_ids[0] if len(token_ids) > 0 else None

found_target_attention = None
past_key_values = None
step_input = None  # with a KV cache, later steps feed only the new token
# Cap the number of generated tokens for memory safety (e.g., max 30)
for step in range(30):
    with torch.no_grad():
        if past_key_values is None:
            # First pass: include the image tensors (pixel_values, etc.)
            outputs = model(**inputs, use_cache=True, output_attentions=True, return_dict=True)
        else:
            # Later passes: the KV cache holds the context, so feed only the new token
            outputs = model(
                input_ids=step_input,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=True,  # request decoder attentions
                return_dict=True,
            )
    next_token_logits = outputs.logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

    # If the target token was just generated, save this step's attentions
    if target_token_id is not None and next_token.item() == target_token_id and found_target_attention is None:
        print(f" - Found '{target_object}' token at generation step {step}.")
        # outputs.attentions: Tuple[num_layers x (B, num_heads, tgt_len, src_len)]
        found_target_attention = outputs.attentions

    # Prepare the next step
    step_input = next_token
    past_key_values = outputs.past_key_values

    # Stop at EOS
    if next_token.item() == processor.tokenizer.eos_token_id:
        print(" - Reached EOS.")
        break
print(" - Manual generation complete.")
# ==============================
# Attention map aggregation (approximate) and visualization
# ==============================
if found_target_attention is not None:
    print(f" - Aggregating attention from {len(found_target_attention)} layers...")
    # Rough estimate of the vision patch grid size (demo only)
    image_size = 448
    patch_size = getattr(model.config.vision_config, "patch_size", 14)
    grid_size = max(1, image_size // patch_size)
    num_patches = grid_size * grid_size
    all_layer_attentions = []
    for layer_attention in found_target_attention:
        # layer_attention: (B, num_heads, tgt_len, src_len)
        # Take the src distribution (row) seen by the query that produced the target token
        token_row = layer_attention[0, :, -1, :]  # (num_heads, src_len)
        # Assume the leading tokens correspond to image patches and take num_patches (approximation)
        image_patch_attention = token_row[:, :num_patches]  # (num_heads, num_patches)
        layer_avg = image_patch_attention.mean(dim=0)  # (num_patches,)
        all_layer_attentions.append(layer_avg)

    final_avg = torch.stack(all_layer_attentions, dim=0).mean(dim=0)  # (num_patches,)
    attention_map = final_avg.float().reshape(grid_size, grid_size).cpu().numpy()

    # Upsample to the original resolution
    resized_map = cv2.resize(attention_map, (image_pil.width, image_pil.height), interpolation=cv2.INTER_CUBIC)

    # Visualization
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1); plt.title("Original Image"); plt.axis('off'); plt.imshow(image_pil)
    plt.subplot(1, 2, 2); plt.title(f"Aggregated Attention for '{target_object}'"); plt.axis('off')
    plt.imshow(image_pil); plt.imshow(resized_map, cmap='jet', alpha=0.5)
    plt.tight_layout()
    plt.savefig("attention_map_result.png", dpi=200)
    print("[OK] Saved attention_map_result.png")
else:
    print(f" - Could not find or generate the token for '{target_object}' within the generation limit.")
$$$
3) Practice Tips & Troubleshooting

- Avoiding out-of-memory (OOM) errors
  - Use the 3B checkpoint if possible
  - Downscale the input image so its long side is 384~448 before feeding it in
  - Cap max_new_tokens to the 48~128 range
- bitsandbytes (4-bit) missing or incompatible
  - Remove the quantization_config=bnb line and run without it (requires more VRAM), or
  - pip install -U bitsandbytes (NVIDIA CUDA environments only)
- Attentions come back as None
  - Depending on the build, generate() may not return attentions; the code above therefore collects decoder attention with a manual generation loop plus output_attentions=True (a spatial approximation)
  - If you need a more accurate spatial heatmap, hook the vision attentions directly or switch to a feature-similarity (cosine) approach, sketched right after this list
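As a fallback, here is a hypothetical sketch of that cosine-similarity route. It assumes you have already extracted `patch_feats` (num_patches × dim vision features) and `text_feat` (a dim-sized embedding for the query word) from the model, e.g., via forward hooks; both names are placeholders, not model attributes.

$$$
import torch
import torch.nn.functional as F

def cosine_heatmap(patch_feats, text_feat, grid_size):
    """Approximate heatmap from feature similarity instead of attention."""
    patch_feats = F.normalize(patch_feats, dim=-1)               # (P, D)
    text_feat = F.normalize(text_feat, dim=-1)                   # (D,)
    sim = patch_feats @ text_feat                                # (P,) cosine similarity
    sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)     # rescale to [0, 1]
    return sim.reshape(grid_size, grid_size)                     # 2-D heatmap
$$$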
4) Next Steps (SAM Integration)

- Sample positive/negative points from the attention map obtained above
- Pass them to segment-anything's SamPredictor.predict(point_coords, point_labels)
- Visualize the resulting segmentation mask (overlay)

The SAM integration demo is covered in a separate post/section (requires a checkpoint such as sam_vit_b.pth). A minimal sketch of the hand-off is shown below.
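This sketch reuses resized_map and image_pil from the script above. The point sampling (a single argmax/argmin) is a deliberately naive stand-in for VL-SAM's sampling strategy, and the checkpoint path is an assumption:

$$$
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load SAM (the checkpoint path is an assumption; download it separately)
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
predictor = SamPredictor(sam)
predictor.set_image(np.array(image_pil))  # RGB uint8 array

# One positive point at the attention peak, one negative at the minimum
pos_y, pos_x = np.unravel_index(np.argmax(resized_map), resized_map.shape)
neg_y, neg_x = np.unravel_index(np.argmin(resized_map), resized_map.shape)
point_coords = np.array([[pos_x, pos_y], [neg_x, neg_y]], dtype=np.float32)
point_labels = np.array([1, 0])  # 1 = foreground, 0 = background

masks, scores, logits = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)
best_mask = masks[np.argmax(scores)]  # keep the highest-scoring mask
$$$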
Wrapping Up

In this post we built a VLM-based attention map and verified the overall flow of VL-SAM.
How attentions are exposed can differ across environments and builds, so start from the loop/approximation strategy proposed above.
If you need them, I will also attach a vision-attention hook version or a lightweight feature-similarity-based version.