Image Classification Using ViT with Python - Exploring the ViT Model for Image Classification
Hello everyone!
In the previous post, we delved into the theory behind ViT based on the original paper! Today, we will actually download this ViT model and perform image classification in a Python environment!!
1. Importing the ViT Model from torchvision (The Simplest Way)
You can easily import the Vision Transformer (ViT) model through torchvision, a core library for image-related tasks in the PyTorch ecosystem.
What kind of package is torchvision, and why does it provide models?
torchvision is a package developed and maintained by the PyTorch team, providing commonly used datasets, image transformations (transforms), and pre-trained model architectures in the field of computer vision.
torchvision provides models for the following reasons:
- Convenience: It supports researchers and developers in easily utilizing models with verified performance without the hassle of implementing image-related deep learning models from scratch.
- Rapid Prototyping: Pre-trained models allow for quick experimentation with new ideas and development of prototypes.
- Saving Learning Resources: Using models pre-trained on large-scale datasets saves the time and computational resources required for direct training.
- Leveraging Learned Representations: Pre-trained models have already learned general image features, enabling good performance on specific tasks with less data (transfer learning), as sketched in the example below.
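For example, a minimal transfer-learning sketch (assuming a recent torchvision where the classifier is exposed as heads.head) could freeze the pretrained backbone and swap in a new head for a hypothetical 10-class dataset:

import torch
import torch.nn as nn
import torchvision.models as models

# Load a ViT-Base/16 backbone with ImageNet weights (torchvision >= 0.13 weights API)
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

# Freeze the pre-trained backbone so only the new head is trained
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head: torchvision's ViT exposes it as model.heads.head
num_classes = 10  # hypothetical number of classes for your own dataset
model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

# Only the new head's parameters will receive gradients during fine-tuning
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['heads.head.weight', 'heads.head.bias']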
Types and Features of ViT Models Provided by torchvision
torchvision provides various CNN-based models as well as ViT models. Currently (as of April 28, 2025), the main types and features of ViT models provided by torchvision are as follows:
| Name | Patch Size | Model Name | Features |
|---|---|---|---|
| ViT-Base | 16x16 | vit_b_16 | Offers a balanced size and performance. |
| ViT-Base | 32x32 | vit_b_32 | The larger patch size reduces computation but may miss fine-grained features. |
| ViT-Large | 16x16 | vit_l_16 | Has more layers and a larger hidden dimension than the Base model, aiming for higher performance. Requires more computational resources. |
| ViT-Large | 32x32 | vit_l_32 | A Large model with a larger patch size. |
| ViT-Huge | 14x14 | vit_h_14 | One of the largest ViT models, aiming for top-level performance but requiring very significant computational resources. |
These models all come with pre-trained weights on the ImageNet dataset, allowing for immediate use in image classification tasks.
The letters 'b', 'l', and 'h' in the model names indicate the Base, Large, and Huge model sizes, respectively, and the number that follows indicates the image patch size.
A larger patch size means the model splits the image into fewer, larger chunks, which reduces the number of tokens and speeds up processing, but it can also miss fine-grained details and lower accuracy. The example below shows how the patch size determines the token count.
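To make the patch-size trade-off concrete, a quick sketch like the one below (assuming torchvision 0.13+ so models can be built with weights=None) counts the patch tokens a 224x224 input produces and compares rough parameter counts:

import torchvision.models as models

# Number of patch tokens for a 224x224 input at each patch size
for name, patch in [("vit_b_16", 16), ("vit_b_32", 32)]:
    n_patches = (224 // patch) ** 2
    print(f"{name}: {n_patches} patch tokens (+1 CLS token)")

# Rough model sizes (weights=None builds the architecture without downloading weights)
for builder in (models.vit_b_16, models.vit_l_16):
    model = builder(weights=None)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{builder.__name__}: {n_params / 1e6:.1f}M parameters")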
2. Today's Image!! Let's Start Classifying!
Today, we will use a cute dog image to see how the ViT model classifies it. The ViT model we will use today is pre-trained on the ImageNet dataset!
What is imagenet_classes?
imagenet_classes is a list of the 1,000 image classes used in the ImageNet Large Scale Visual Recognition Challenge (ILSVRC). The pre-trained ViT models provided by torchvision are trained on this ImageNet dataset, so the model's output consists of prediction probabilities over these 1,000 classes. imagenet_classes maps these numerical prediction results to human-readable class names (e.g., "golden retriever", "poodle").
imagenet_classes.json: a JSON file containing the imagenet_classes information.
Since torchvision itself does not directly include the ImageNet class-name list, you need to prepare a separate JSON file containing this information. You can obtain the imagenet_classes.json file in the following way:
import requests
import json
# Read JSON file directly from URL
url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
response = requests.get(url)
response.raise_for_status() # Raise an error for bad status codes
# Load JSON data
imagenet_labels = response.json()
with open("imagenet_classes.json", "w") as f:
    json.dump(imagenet_labels, f)
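As a quick sanity check, you can reload the saved file and confirm it holds 1,000 labels; assuming the imagenet-simple-labels file keeps the standard ILSVRC ordering, index 207 should be the golden retriever label we will see later:

import json

# Reload the saved class-name list and inspect it
with open("imagenet_classes.json", "r") as f:
    imagenet_classes = json.load(f)

print(len(imagenet_classes))   # expected: 1000
print(imagenet_classes[207])   # expected: the golden retriever label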
3. Let's Begin the Code!!
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import json

# 1. Load ViT model (ViT-Base, patch size 16) with ImageNet-pretrained weights
vit_b_16 = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
vit_b_16.eval()  # Set the model to evaluation mode

# 2. Define image preprocessing
# Resize images to 256 and then center crop to 224.
# Normalize using the mean and standard deviation of the ImageNet dataset.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 3. Load the dog image (replace with your image file path)
image_path = "dog.jpg"
try:
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0)  # Add batch dimension
except FileNotFoundError:
    print(f"Error: Image file '{image_path}' not found.")
    exit()

# 4. Perform prediction
with torch.no_grad():
    output = vit_b_16(input_tensor)

# 5. Post-process the prediction results and print the class names
_, predicted_idx = torch.sort(output, dim=1, descending=True)
top_k = 5
try:
    with open("imagenet_classes.json", "r") as f:
        imagenet_classes = json.load(f)
    print(f"Top {top_k} prediction results:")
    for i in range(top_k):
        class_idx = predicted_idx[0, i].item()
        confidence = torch.softmax(output, dim=1)[0, class_idx].item()
        print(f"- {imagenet_classes[class_idx]}: {confidence:.4f}")
except FileNotFoundError:
    print("Error: 'imagenet_classes.json' file not found. Please prepare the file as shown in section 2.")
    print("Predicted class indices:", predicted_idx[0, :5].tolist())
except Exception as e:
    print(f"Error during prediction processing: {e}")
When you run the code above, you can see the Top 5 prediction results as shown below!
Top 5 Prediction Results:
- Golden Retriever: 0.9126
- Labrador Retriever: 0.0104
- Kuvasz: 0.0032
- Airedale Terrier: 0.0014
- tennis ball: 0.0012
We can see that the Golden Retriever is predicted with the highest probability of 91.26%.
4. Getting and Running the Model Directly from Hugging Face! + Analysis (Less Simple, But Customizable)
This time, let's import the ViT model directly from Hugging Face and proceed!
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image

# 1. Load the ViT model from Hugging Face (ViT-Base, patch size 16, fine-tuned on ImageNet-1k)
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.eval()  # Set the model to evaluation mode

# 2. Load the feature extractor (it applies the same preprocessing used during training)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# 3. Load the dog image (replace with your image file path)
image_path = "dog.jpg"
try:
    image = Image.open(image_path).convert('RGB')
    inputs = feature_extractor(images=image, return_tensors="pt")
except FileNotFoundError:
    print(f"Error: Image file '{image_path}' not found.")
    exit()

# 4. Perform prediction
with torch.no_grad():
    logits = model(**inputs).logits

# 5. Post-process the prediction results and print the class names
# The ImageNet class names are bundled with the model config, so no separate JSON file is needed.
probabilities = torch.softmax(logits, dim=1)
top_probs, top_indices = probabilities.topk(5, dim=1)
print("Top 5 prediction results:")
for prob, idx in zip(top_probs[0], top_indices[0]):
    print(f"- {model.config.id2label[idx.item()]} (index {idx.item()}): {prob.item():.4f}")
Similarly, it was classified as index 207, Golden Retriever!!!
Now, let's look at how this differs from the torchvision approach, along with some model customization options!
a. Image Preprocessing Method!!
Looking at the preprocessing snippet below, ViTFeatureExtractor already knows the preprocessing used when the model was trained, so you can preprocess images with a single call instead of writing out a transforms.Compose pipeline yourself!
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
# 3. Preprocessing: no need to crop and resize manually
inputs = feature_extractor(images=image, return_tensors="pt")
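To see what "already knows the preprocessing" means, a small sketch like this (using the image_mean, image_std, and size attributes that Hugging Face image processors expose) prints the configuration shipped with google/vit-base-patch16-224:

from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# The normalization statistics and target size were saved alongside the model
print("resize/crop size:", feature_extractor.size)
print("mean:", feature_extractor.image_mean)
print("std:", feature_extractor.image_std)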
b. Viewing the CLS Token!!
In the previous theory post, we learned that the input consists of 196 patch tokens + 1 CLS token, for a total of 197 tokens! We also confirmed that the overall information of the image is aggregated in this first CLS token. You can inspect the CLS token with the following code!
from transformers import ViTModel, ViTImageProcessor
import torch
from PIL import Image

# 1. ViTModel (pure backbone without a classification head)
model = ViTModel.from_pretrained('google/vit-base-patch16-224')
model.eval()

# Feature Extractor -> updated to ViTImageProcessor
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# 2. Load image
image = Image.open("dog.jpg").convert('RGB')
inputs = processor(images=image, return_tensors="pt")

# 3. Model inference
with torch.no_grad():
    outputs = model(**inputs)

# 4. Extract CLS token
last_hidden_state = outputs.last_hidden_state  # (batch_size, num_tokens, hidden_dim)
cls_token = last_hidden_state[:, 0, :]  # The 0th token is CLS

# 5. Print CLS token
print("CLS token shape:", cls_token.shape)  # torch.Size([1, 768])
print("CLS token values (first 5):", cls_token[0, :5])
If you run the code above, you can see the 768-dimensional CLS token as expected! Later research uses this token in many downstream ways, for example as a compact image embedding, as sketched after the output below.
CLS token shape: torch.Size([1, 768])
CLS token values (first 5): tensor([-0.5934, -0.3203, -0.0811, 0.3146, -0.7365])
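As one example of such downstream use, a minimal sketch (with a hypothetical second image, cat.jpg, that you can replace with any file) compares two images by the cosine similarity of their CLS embeddings:

import torch
from transformers import ViTModel, ViTImageProcessor
from PIL import Image

model = ViTModel.from_pretrained('google/vit-base-patch16-224')
model.eval()
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

def cls_embedding(path):
    # Return the CLS token of an image as a (768,) vector
    image = Image.open(path).convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state[:, 0, :].squeeze(0)

# Compare two images via cosine similarity of their CLS embeddings
emb_a = cls_embedding("dog.jpg")
emb_b = cls_embedding("cat.jpg")  # hypothetical second image
similarity = torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=0)
print(f"CLS cosine similarity: {similarity.item():.4f}")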
c. ViT's CAM!! Attention Rollout
In traditional CNN-based image classification, a CAM (Class Activation Map) computed at the end of the model could visualize which parts of the image were important!!!
CAM Theory Summary!!
CAM Practice!!
Our ViT model has a different structure, so CAM cannot be applied in the same way! However, with a method called Attention Rollout, you can visualize which of the 196 image patches the CLS token paid the most attention to!
Looking at the structure!!
As shown below, attention is the process by which [CLS] assigns a weight to each patch, effectively saying "you're important" or "you're not important", and Attention Rollout aggregates these attention weights across layers so they can be visualized!
[CLS] → Patch_1 (Attention weight: 0.05)
[CLS] → Patch_2 (Attention weight: 0.02)
[CLS] → Patch_3 (Attention weight: 0.01)
...
[CLS] → Patch_196 (Attention weight: 0.03)
In the end!! You can see a visualization of which patches were considered important as below!
- Red areas → patches that [CLS] paid much attention to.
- Blue areas → patches that [CLS] paid less attention to.
Looking at the code:
from transformers import ViTModel, ViTFeatureExtractor
import torch
from PIL import Image
import matplotlib.pyplot as plt

# 1. Load model and Feature Extractor (ask the model to return attention maps)
model = ViTModel.from_pretrained('google/vit-base-patch16-224', output_attentions=True)
model.eval()
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# 2. Load Image
image = Image.open("dog.jpg").convert('RGB')
inputs = feature_extractor(images=image, return_tensors="pt")

# 3. Model Inference (output attention)
with torch.no_grad():
    outputs = model(**inputs)
attentions = outputs.attentions  # one (batch, heads, tokens, tokens) tensor per layer

# 4. Calculate Attention Rollout
def compute_rollout(attentions):
    # Multiply attention matrices across layers (adding the identity for residual connections)
    result = torch.eye(attentions[0].size(-1))
    for attention in attentions:
        attention_heads_fused = attention.mean(dim=1)[0]  # average over heads -> (tokens, tokens)
        attention_heads_fused += torch.eye(attention_heads_fused.size(-1))
        attention_heads_fused /= attention_heads_fused.sum(dim=-1, keepdim=True)  # re-normalize rows
        result = torch.matmul(result, attention_heads_fused)
    return result

rollout = compute_rollout(attentions)

# 5. Extract Attention from the [CLS] token to the 196 image patches (14x14 grid)
mask = rollout[0, 1:].reshape(14, 14).detach().cpu().numpy()

# 6. Visualization
def show_mask_on_image(img, mask):
    img = img.resize((224, 224))
    mask = (mask - mask.min()) / (mask.max() - mask.min())  # normalize to [0, 1]
    fig, ax = plt.subplots()
    ax.imshow(img)
    # Stretch the 14x14 mask over the 224x224 image so the overlay lines up
    ax.imshow(mask, cmap='jet', alpha=0.5, extent=(0, 224, 224, 0))
    ax.axis('off')
    plt.show()

show_mask_on_image(image, mask)
And here is the result!
Does it look about right~?
5. Conclusion: Simple and Fast ViT
How was it? We ran the code ourselves, and it was quick and easy to get working!
ViT is not only significant in theory: because models pre-trained on large-scale datasets can be used with just a few lines of code, Transformer-based research has exploded in computer vision ever since!
In the future, we will also explore and practice various Vision Transformer-based models such as DINO, DeiT, CLIP, Swin Transformer, etc.! ^^
Thank you!!!