
Pretrained Models and Transfer Learning

1. Pretrained Models


1.1 Pretrained Models

  • Models with strong performance tend to be very heavy and take a long time to train
    • ResNet-50 was trained for 29 hours on 8 P100 GPUs
    • Facebook's AI team cut this down to 1 hour, but by using 256 GPUs instead
  • Many researchers publish the models they have trained
  • These published models are usually called pretrained models, and they can be used as-is or as the basis for transfer learning


1.2 TensorFlow Hub

  • TensorFlow Hub is a library of reusable machine learning modules
  • Install: pip install tensorflow_hub
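Besides the KerasLayer wrapper used in the next section, a Hub module can also be loaded directly with hub.load; a minimal sketch:

import tensorflow_hub as hub

# hub.load returns the SavedModel itself as a callable object,
# handy when you want the module outside of a Keras Sequential model
module = hub.load('https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2')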


1.3 MobileNet V2

  • MobileNet V2 is a lightweight image-classification network designed to run efficiently on mobile and embedded devices; the pretrained model loaded below classifies 224×224 images into 1,001 ImageNet classes


1.4 Loading MobileNet V2 (a pretrained model)

import tensorflow as tf
import numpy as np
import tensorflow_hub as hub

# MobileNet V2 image classifier from TensorFlow Hub, used as a frozen layer
url = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2'
model = tf.keras.Sequential([
    hub.KerasLayer(handle=url, input_shape=(224, 224, 3), trainable=False)
])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 1001)              3540265   
=================================================================
Total params: 3,540,265
Trainable params: 0
Non-trainable params: 3,540,265
_________________________________________________________________
  • The parameter count is quite small; all 3,540,265 parameters are non-trainable because the layer was loaded with trainable=False


1.5 Test images: ImageNetV2

  • ImageNetV2 is a new test set built from a portion of ImageNet's data
  • Its labels were collected through Amazon Mechanical Turk
  • Mechanical Turk is a platform that provides relatively cheap human labor for tasks that need a lot of manual work, such as image labeling


1.6 Downloading ImageNetV2

import pathlib

# Download and extract ImageNetV2 (top-images variant) to a local directory
im_url = 'https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-topimages.tar.gz'
data_root_orig = tf.keras.utils.get_file('imagenetV2', im_url,
                                         cache_dir='./data/imagenetv2-topimages',
                                         extract=True)
data_root = pathlib.Path('./data/imagenetv2-topimages')
print(data_root)
data/imagenetv2-topimages
  • Setting cache_dir lets you download ImageNetV2 to a directory of your choice


1.7 Checking that the data path looks right

# Print the first few class directories to sanity-check the download
for idx, item in enumerate(data_root.iterdir()):
    print(item)
    if idx == 5:
        break
data/imagenetv2-topimages/797
data/imagenetv2-topimages/909
data/imagenetv2-topimages/135
data/imagenetv2-topimages/307
data/imagenetv2-topimages/763
data/imagenetv2-topimages/551
  • The dataset checks out; each numbered directory corresponds to an ImageNet class


1.8 Loading the labels

# Download the 1,001 ImageNet label strings (index 0 is 'background')
label_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
label_file = tf.keras.utils.get_file('label', label_url)
with open(label_file, 'r') as f:
    label_text = f.read().split('\n')[:-1]
print(len(label_text))
print(label_text[:10])
print(label_text[-10:])
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 0us/step
1001
['background', 'tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen']
['buckeye', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn', 'earthstar', 'hen-of-the-woods', 'bolete', 'ear', 'toilet tissue']
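One detail worth keeping in mind: label_text[0] is 'background', so the 1,000 real ImageNet classes sit at indices 1-1000, while the ImageNetV2 directory names use the original 0-999 class indices. That is why the code below adds 1 when looking up a label. A quick check (the directory number here is just an example):

# ImageNet class k (0-999) maps to label_text[k + 1]
# because index 0 is occupied by 'background'
example_class = 1                      # e.g. images under directory '1'
print(label_text[example_class + 1])   # -> 'goldfish'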


1.9 Viewing the images

import random

# Collect every image path and shuffle them
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
print(f'Image count: {image_count}')
Image count: 10002
  • There are 10,002 images in total
  • What do they actually look like?


import matplotlib.pyplot as plt

plt.figure(figsize=(12, 12))
for c in range(9):
    image_path = random.choice(all_image_paths)
    plt.subplot(3, 3, c + 1)
    plt.imshow(plt.imread(image_path))
    # Directory names are ImageNet class indices (0-999);
    # add 1 because label_text[0] is 'background'
    idx = int(image_path.split('/')[-2]) + 1
    plt.title(str(idx) + ', ' + label_text[idx])
    plt.axis('off')
plt.show()

  • This is what they look like


1.10 Test

  • pip install opencv-python
  • pip install opencv-contrib-python
import cv2

# Pick a random image and derive its true label from the directory name
img = random.choice(all_image_paths)
label = int(img.split('/')[-2]) + 1          # +1 for the 'background' offset
img_draw = cv2.imread(img)
img_draw = cv2.cvtColor(img_draw, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model expects RGB
img_resized = cv2.resize(img_draw, dsize=(224, 224))
img_resized = img_resized / 255.0
img_resized = np.expand_dims(img_resized, axis=0)

# Indices of the five highest-scoring classes
top_5_predict = model.predict(img_resized)[0].argsort()[::-1][:5]
print(top_5_predict)
print(label)
if label in top_5_predict:
    print('Answer is correct !!')
print(f'Predicted answer is {label_text[top_5_predict[0]]}')

plt.imshow(plt.imread(img))
plt.show()
[827 427 636 836 713]
827
Answer is correct !!
Predicted answer is stopwatch

  • Feed the image into the pretrained model and predict; if the true label is among the top 5 predictions (by score), print that the answer is correct
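The steps above fold naturally into a small helper for repeated spot-checks. A minimal sketch (predict_top5 is a name introduced here; it assumes the model, label_text, and all_image_paths defined earlier):

def predict_top5(image_path):
    # Load, convert to RGB, resize, scale, and batch a single image,
    # then return the indices of the five highest-scoring classes
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, dsize=(224, 224)) / 255.0
    logits = model.predict(np.expand_dims(img, axis=0))[0]
    return logits.argsort()[::-1][:5]

# Spot-check one random image
path = random.choice(all_image_paths)
true_label = int(path.split('/')[-2]) + 1   # the 'background' offset again
print(true_label in predict_top5(path))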


1.11 Computing softmax and visualizing the probabilities

def softmax(x):
    # Subtract the max for numerical stability before exponentiating
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

logits = model.predict(img_resized)[0]
prediction = softmax(logits)

top_5_predict = prediction.argsort()[::-1][:5]
labels = [label_text[index] for index in top_5_predict]
  • Compute the softmax to get probability values for the top-5 labels
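For instance, the top-5 probabilities can be printed directly (assuming the variables computed above):

# Print each top-5 label with its softmax probability
for index in top_5_predict:
    print(f'{label_text[index]}: {prediction[index] * 100:.2f}%')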


plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.imshow(plt.imread(img))

idx = int(img.split('/')[-2]) + 1
plt.title(str(idx) + ', ' + label_text[idx])
plt.axis('off')

# Horizontal bar chart of the top-5 probabilities;
# the bar for the true label (if present) is drawn in green
plt.subplot(1, 2, 2)
color = ['gray'] * 5
if idx in top_5_predict:
    color[top_5_predict.tolist().index(idx)] = 'green'
color = color[::-1]
plt.barh(range(5), prediction[top_5_predict][::-1] * 100, color=color)
plt.yticks(range(5), labels[::-1])
plt.show()

  • Visualizes the probabilities of the top-5 labels for the image


2. Transfer Learning


2.1 Transfer Learning

  • Transfer learning takes some of the layers from an existing pretrained model and reuses them to build a similar model for a new task; a schematic sketch follows
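The general recipe, which section 2.5 applies concretely, looks like this. A schematic sketch only (not the exact code used below; the layer sizes and optimizer are placeholders):

import tensorflow as tf

# 1. Load a pretrained base without its classification head
base = tf.keras.applications.MobileNetV2(include_top=False, pooling='avg')
# 2. Freeze the pretrained feature extractor
base.trainable = False
# 3. Attach a new head sized for the new task (120 dog breeds here)
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(120, activation='softmax')
])
# 4. Train only the new head on the new dataset
model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])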


2.2 Data for applying transfer learning to the pretrained model

https://www.kaggle.com/c/dog-breed-identification/data

  • A dataset of dog photos labeled with each dog's breed
  • Given a photo, the task is to predict the breed
  • The data can be downloaded from the link above
  • For convenience, the folder was renamed to dog_data


2.3 Checking the label data

import pandas as pd

# Read the labels CSV: one row per training image
label_text = pd.read_csv('./data/dog_data/labels.csv')
print(label_text.head())
                                 id             breed
0  000bec180eb18c7604dcecc8fe0dba07       boston_bull
1  001513dfcb2ffafc82cccf4d8bbaba97             dingo
2  001cdf01b096e06d78e9e5112d419397          pekinese
3  00214f311d5d2247d5dfe4fe24b2303d          bluetick
4  0021f9ceb3235effd7fcde7f7538ed62  golden_retriever
  • The labels file holds an id and a breed for each image


label_text.info()
print(label_text['breed'].nunique(), 'breeds')
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10222 entries, 0 to 10221
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   id      10222 non-null  object
 1   breed   10222 non-null  object
dtypes: object(2)
memory usage: 159.8+ KB
120 breeds
  • More than ten thousand dog images across 120 breeds


2.4 Visualizing the data

plt.figure(figsize=(12, 8))
for c in range(6):
    image_id = label_text.loc[c, 'id']
    plt.subplot(2, 3, c + 1)
    plt.imshow(plt.imread('./data/dog_data/train/' + image_id + '.jpg'))
    plt.title(str(c) + ', ' + label_text.loc[c, 'breed'])
    plt.axis('off')
plt.show()

  • These are the kinds of photos in the dataset


2.5 Loading MobileNet V2

from tensorflow.keras.applications import MobileNetV2
import tensorflow as tf

mobilev2 = MobileNetV2()

# Take the output of the second-to-last layer (global average pooling)
# and attach a new 120-way softmax head for the dog breeds
x = mobilev2.layers[-2].output
predictions = tf.keras.layers.Dense(120, activation='softmax')(x)
model = tf.keras.Model(inputs=mobilev2.input, outputs=predictions)

# Freeze everything except the last 20 layers
for layer in model.layers[:-20]:
    layer.trainable = False
for layer in model.layers[-20:]:
    layer.trainable = True

model.compile(optimizer='sgd',
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_2[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 112, 112, 32) 288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 112, 112, 32) 128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 112, 112, 32) 0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 112, 112, 16) 512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 112, 112, 16) 64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 112, 112, 96) 1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 112, 112, 96) 384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 112, 112, 96) 0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 113, 113, 96) 0           block_1_expand_relu[0][0]        
...
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 7, 7, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 7, 7, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 7, 7, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 7, 7, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 7, 7, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 7, 7, 1280)   0           Conv_1_bn[0][0]                  
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 1280)         0           out_relu[0][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 120)          153720      global_average_pooling2d_1[0][0] 
==================================================================================================
Total params: 2,411,704
Trainable params: 1,204,280
Non-trainable params: 1,207,424
__________________________________________________________________________________________________
  • Only the last 20 layers of the network are set to train
  • Total params and Trainable params differ because only those last 20 layers (including the new Dense head) are updated; the rest stay frozen
  • Even so, the summary shows there is still a huge amount going on; a quick programmatic check follows
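You can verify the trainable/frozen split programmatically (a quick check, assuming the model built above):

from tensorflow.keras import backend as K

# Sum parameter counts over trainable and frozen weights separately
trainable = sum(K.count_params(w) for w in model.trainable_weights)
frozen = sum(K.count_params(w) for w in model.non_trainable_weights)
print(trainable, frozen)  # should match the summary: 1,204,280 and 1,207,424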


2.6 Creating the train X and Y data

import cv2

# Load every training image, resize to the model's input size, and scale to [0, 1]
train_X = []
for i in range(len(label_text)):
    img = cv2.imread('./data/dog_data/train/' + label_text['id'][i] + '.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # pretrained features expect RGB, not OpenCV's BGR
    img = cv2.resize(img, dsize=(224, 224))
    img = img / 255.0
    train_X.append(img)
train_X = np.array(train_X)
print(train_X.shape)
print(train_X.size * train_X.itemsize, ' bytes')
(10222, 224, 224, 3)
12309577728  bytes
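That is roughly 12.3 GB, because np.array builds a float64 array here (8 bytes per value). A small aside, not in the original: casting to float32 halves the footprint with no practical loss for this model:

# float32 uses 4 bytes per value instead of 8
train_X = train_X.astype(np.float32)
print(train_X.size * train_X.itemsize, ' bytes')  # about 6.15 GB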


# Encode each breed name as an integer class index
unique_Y = label_text['breed'].unique().tolist()
train_Y = [unique_Y.index(breed) for breed in label_text['breed']]
train_Y = np.array(train_Y)
print(train_Y[:10])
print(train_Y[-10:])
[0 1 2 3 4 5 5 6 7 8]
[34 87 91 63 48  6 93 63 77 92]
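Holding all ~12 GB of pixels in memory works but is wasteful. An alternative sketch (assuming the same directory layout and the train_Y just built) streams images from disk with tf.data instead of preloading:

# Build a pipeline that decodes and resizes images on the fly
paths = ['./data/dog_data/train/' + i + '.jpg' for i in label_text['id']]

def load_image(path, label):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (224, 224)) / 255.0
    return img, label

dataset = tf.data.Dataset.from_tensor_slices((paths, train_Y)).map(load_image).batch(32)
# model.fit(dataset, epochs=10) would then train without the giant array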


2.7 Fit

history = model.fit(train_X, train_Y, epochs=10, validation_split=0.25, batch_size=32)
Epoch 1/10
240/240 [==============================] - 102s 424ms/step - loss: 3.2485 - accuracy: 0.2904 - val_loss: 1.9760 - val_accuracy: 0.4534
Epoch 2/10
240/240 [==============================] - 107s 445ms/step - loss: 1.6241 - accuracy: 0.6054 - val_loss: 1.6033 - val_accuracy: 0.5477
Epoch 3/10
240/240 [==============================] - 108s 449ms/step - loss: 1.1931 - accuracy: 0.7131 - val_loss: 1.4534 - val_accuracy: 0.5943
Epoch 4/10
240/240 [==============================] - 108s 449ms/step - loss: 0.9406 - accuracy: 0.7747 - val_loss: 1.3683 - val_accuracy: 0.6072
Epoch 5/10
240/240 [==============================] - 106s 441ms/step - loss: 0.7763 - accuracy: 0.8298 - val_loss: 1.3273 - val_accuracy: 0.6217
Epoch 6/10
240/240 [==============================] - 114s 474ms/step - loss: 0.6375 - accuracy: 0.8729 - val_loss: 1.3028 - val_accuracy: 0.6295
Epoch 7/10
240/240 [==============================] - 107s 447ms/step - loss: 0.5297 - accuracy: 0.9023 - val_loss: 1.2905 - val_accuracy: 0.6307
Epoch 8/10
240/240 [==============================] - 107s 446ms/step - loss: 0.4413 - accuracy: 0.9314 - val_loss: 1.2837 - val_accuracy: 0.6405
Epoch 9/10
240/240 [==============================] - 108s 448ms/step - loss: 0.3795 - accuracy: 0.9510 - val_loss: 1.2573 - val_accuracy: 0.6518
Epoch 10/10
240/240 [==============================] - 108s 449ms/step - loss: 0.3188 - accuracy: 0.9622 - val_loss: 1.2633 - val_accuracy: 0.6424
  • This is with transfer learning; without it, training would take far longer


2.8 Checking the training curves

plt.figure(figsize=(12, 8))
plt.plot(history.history['loss'], 'b-', label='loss')
plt.plot(history.history['val_loss'], 'r--', label='val_loss')
plt.plot(history.history['accuracy'], 'g-', label='accuracy')
plt.plot(history.history['val_accuracy'], 'k--', label='val_accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.show()

  • It does look like overfitting (training accuracy keeps rising while validation loss flattens out), but training wasn't very long, so it's hard to say for sure; one common guard is sketched below
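One common guard against overfitting is an EarlyStopping callback that halts training when validation loss stalls and restores the best weights. A sketch (the patience value is an arbitrary choice, not from the original run):

# Stop when val_loss stops improving; roll back to the best epoch
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)

history = model.fit(train_X, train_Y, epochs=30, validation_split=0.25,
                    batch_size=32, callbacks=[early_stop])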
This post is licensed under CC BY 4.0 by the author.