DINO Python Experiment!! Super Impressive!!
In the previous post, we learned the theory behind DINO!!
Today, let's actually run the DINO model and see how it performs~!
- Starting with the conclusion today!!!
- Using attention, it highlights the important parts of an image, tada~!
- Isn't that amazing!?
- Let's explore how it works~!
1. What is timm?!!
In this post, we'll load the DINO model using timm.
Let's first understand what timm (Torch Image Models) is!
- timm stands for Torch Image Models, a library that provides a wide array of tools and pretrained models for handling image tasks in PyTorch!!
Main features of timm:
- Offers various modern image models:
- Includes ResNet, EfficientNet, Vision Transformer (ViT), Swin Transformer, and more, easily usable for image classification, detection, semantic segmentation, etc.
- Rich pretrained weights:
- Provides weights pretrained on large datasets such as ImageNet and JFT, which makes transfer learning easy without training from scratch
- Easy model creation:
- With timm.create_model(), you can create your desired model by name and conveniently load pretrained weights
- Modular design:
- Easily access and modify components like the backbone, pooling layer, and classifier head, which makes it highly flexible for building custom models or fine-tuning existing ones
- Various utility functions:
- Offers helpful tools for image transforms, dataset handling, optimizers, schedulers, etc. (a quick sketch follows right after this list)
- Active community:
- An open-source project actively maintained and continuously updated with new models and features
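For example, the data utilities can rebuild the exact preprocessing a checkpoint was trained with~! A minimal sketch, assuming a recent timm version that ships the resolve_model_data_config / create_transform helpers:

import timm

model = timm.create_model('vit_base_patch16_224_dino', pretrained=True)

# Look up the preprocessing this checkpoint expects (input size, mean/std, ...)
data_config = timm.data.resolve_model_data_config(model)

# Build the matching evaluation transform from that config
transform = timm.data.create_transform(**data_config, is_training=False)
print(transform)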
- Example of using timm:
import timm
# USE DINO-ViT MODEL (pretrained)
model = timm.create_model('vit_base_patch16_224_dino', pretrained=True)
model.eval()
This loads the DINO model structure!!
(We'll focus on hands-on today; for architecture details, please check the theory post~)
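Before moving on, it can help to sanity-check what was loaded. A minimal sketch, assuming timm's num_classes argument and default_cfg attribute (both present in common timm versions):

import timm

# num_classes=0 returns the backbone as a pure feature extractor (no classifier head)
model = timm.create_model('vit_base_patch16_224_dino', pretrained=True, num_classes=0)

# Parameter count (~86M for ViT-B/16) and the expected input resolution
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
print(model.default_cfg['input_size'])  # e.g. (3, 224, 224)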
2. Encoding Images with ViT-based DINO!! (Into Vectors)
The core idea of ViT is turning an image into a vector using Transformer techniques!!
Letโs start with this image of someone holding a fork!!
And now!
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import timm
import torchvision.transforms as T
# USE DINO-ViT MODEL (pretrained)
model = timm.create_model('vit_base_patch16_224_dino', pretrained=True)
model.eval()
# Load image
image_path = "hold_fork.jpg"
image = Image.open(image_path).convert('RGB')
# Image preprocess
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img_tensor = transform(image).unsqueeze(0)

# Run the model (inference only, no gradients)
with torch.no_grad():
    outputs = model.forward_features(img_tensor)  # Shape: (batch_size, 197, feature_dim)

print(outputs.shape)
Breaking down the code above:
- Load the model
- Load the image as RGB vector
- Preprocess and normalize the image to (224, 224)
- Feed it into DINO and get the final output!!
And the output will be:
torch.Size([1, 197, 768])
So the output is a tensor of shape 197 (1 CLS token + 196 patch tokens) × 768 (DINO's embedding dimension)!!
That's the end of the image encoding process!!!!
You can now analyze each patch token or the CLS token depending on your purpose~~!!
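To make the token bookkeeping concrete, here is a minimal sketch of how those 197 positions split, continuing from the outputs tensor above (the arithmetic is just 224 / 16 = 14):

# A 224x224 input cut into 16x16 patches gives a 14x14 grid of patch tokens
img_size, patch_size = 224, 16
grid = img_size // patch_size   # 14
num_patches = grid * grid       # 196

cls_token = outputs[:, 0]       # (1, 768): a global summary of the whole image
patch_tokens = outputs[:, 1:]   # (1, 196, 768): one vector per patch
assert patch_tokens.shape[1] == num_patches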
3. Visualizing the Encoded Output!! (Decoding)
The result is in vector form: great for computers,
but hard for us to interpret, right?
Letโs decode it so we can actually see it!
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import timm
import torchvision.transforms as T
# USE DINO-ViT MODEL (pretrained)
model = timm.create_model('vit_base_patch16_224_dino', pretrained=True)
model.eval()
# Load image
image_path = "hold_fork.jpg"
image = Image.open(image_path).convert('RGB')
# Image preprocess
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img_tensor = transform(image).unsqueeze(0)

# Run the model (inference only, no gradients)
with torch.no_grad():
    outputs = model.forward_features(img_tensor)  # Shape: (batch_size, 197, feature_dim)

# Extract patch tokens (excluding the CLS token)
patch_tokens = outputs[:, 1:, :]  # (batch_size, 196, feature_dim)

# Importance map: use the norm of each patch token as a saliency proxy
attn_map = torch.norm(patch_tokens, dim=-1).reshape(14, 14)  # (14x14)

# Normalize (scale to range 0-1)
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
# Visualize full Attention Map
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Original image
ax[0].imshow(image)
ax[0].axis('off')
ax[0].set_title('Original Image')
# Attention Map
attn_map_resized = np.array(Image.fromarray(attn_map.numpy()).resize(image.size, resample=Image.Resampling.BILINEAR))
ax[1].imshow(image)
ax[1].imshow(attn_map_resized, cmap='jet', alpha=0.5)  # overlay the importance map as a heatmap
ax[1].axis('off')
ax[1].set_title('DINO-ViT Attention Map')
plt.tight_layout()
plt.show()
This code builds upon the previous one by adding visualization!
The most important part is:
# Run the model (inference only, no gradients)
with torch.no_grad():
    outputs = model.forward_features(img_tensor)  # Shape: (batch_size, 197, feature_dim)

# Extract patch tokens (exclude the CLS token)
patch_tokens = outputs[:, 1:, :]  # (batch_size, 196, feature_dim)

# Importance map: patch-token norm as a saliency proxy
attn_map = torch.norm(patch_tokens, dim=-1).reshape(14, 14)  # (14x14)

# Normalize (scale to 0-1)
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
This slicing drops 1 of the 197 token outputs: the CLS token.
Understanding why is essential!!
The CLS token is a global summary of the whole image rather than a spatial patch, so only the remaining 196 tokens map back onto the 14x14 patch grid!!
Once visualized, you get the same result as shown at the beginning of this post:
The DINO model, trained without any labels,
intelligently identifies and highlights important regions in red,
while marking less important ones in blue!
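One caveat: the map above uses patch-token norms as an importance proxy, while the DINO paper itself visualizes the CLS token's self-attention from the last block. Here is a sketch of that variant, assuming timm's usual ViT layout (model.blocks[-1].attn with a fused qkv projection; attribute names can differ across timm versions):

captured = {}

def grab_attention(module, inputs, output):
    # Recompute the attention matrix from the block's fused qkv projection
    x = inputs[0]                                    # tokens entering attention: (B, N, C)
    B, N, C = x.shape
    qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads)
    q, k, _ = qkv.permute(2, 0, 3, 1, 4)             # each: (B, heads, N, head_dim)
    attn = (q @ k.transpose(-2, -1)) * module.scale
    captured['attn'] = attn.softmax(dim=-1)          # (B, heads, N, N)

handle = model.blocks[-1].attn.register_forward_hook(grab_attention)
with torch.no_grad():
    model.forward_features(img_tensor)
handle.remove()

# CLS row: how strongly the CLS token attends to each patch, averaged over heads
cls_attn = captured['attn'][0, :, 0, 1:].mean(0).reshape(14, 14)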
4. Conclusion!!
With DINO, it's incredibly easy to turn images into vectors and visualize them!!
Building and training the model may have been tough,
but actually using it is super simple and impressive!!
Definitely something we should remember and leverage in future research~!
Also, big thanks to timm for making model usage so convenient!
It supports not just DINO, but many other models as well!
timm.list_models()
You can use this to see the long list of available models~
In my version, over 1,200 models are available!
I also saw ResNet, Swin, RegNet, EfficientNet, and more; looks like I need to study those too!!
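By the way, list_models also accepts a wildcard filter, which is handy for finding the exact identifier of a model family~! A quick sketch (the exact names returned depend on your timm version):

import timm

print(len(timm.list_models()))             # total registered models (1,200+ here)
print(timm.list_models('*dino*'))          # just the DINO variants
print(timm.list_models('efficientnet*'))   # just the EfficientNet family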