MNIST 데이터를 통해 간단한 분류 모델을 생성하며 텐서플로를 실습해 본다.
MNIST 데이터
MINST 데이터는 손으로 쓴 0~9까지의 숫자 데이터로 훈련을 위한 6만개의 데이터와 테스트를 위한 1만개의 뎉이터로 이루어져 있따. 머신 러닝 알고리즘을 학습하기 좋은 간편한 데이터 집합이다.
각 이미지는 28x28의 크기이며 0~9까지의 숫자를 표시하기 위한 라벨은 0~9의 값으로 표현되어 있다.
위의 링크에서 훈련용 이미지/라벨과 테스트용 이미지/라벨 파일을 볼 수 있다. 이미지 데이터와 라벨 데이터를 표현하는 파일이 각각 존재한다. 파일 형식은 MSB(Most Significant Bit) 방식으로 저장되어 있다.
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
라벨의 value는 0에서 9까지의 숫자이다.
xxxxxxxxxx
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
픽셀들은 row-wise로 구성되어 있다. 픽셀 값은 0에서 255이다. 0은 배경(흰색), 255는 전경색(검정색)을 의미한다. 하지만 텐서플로에는 이 MNIST 데이터셋이 포함되어 있기 떄문에 간편히 이용할 수 있다.
텐서플로 실습
MNIST 데이터를 다운받기 위해 다음과 같은 코드를 실행한다.
xfrom tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('MNIST_data', one_hot = True);
다운로드가 완료되면 숫자를 분류하기 위한 퍼셉트론을 구축한다.
xxxxxxxxxx
import tensorflow as tf
input_size = 784
num_of_class = 10
batch_size = 100
total_batches = 200
플레이스 홀더는 데이터가 입출력되는 통로이다. 플레이스 홀더는 특정 값은 아니지만 계산 과정에서 입력을 받는다. 퍼셉트론의 입력 크기, 클래스 수, 배치 크기, 반복/배치의 총 개수를 선언해 둔다. input_size
가 784인 것은 MNIST 데이터에서 이미지가 28*28 형태이기 때문이다. 0~9까지 숫자 분류이기 떄문에 num_of_class
는 10이다.
xxxxxxxxxx
x_input = tf.placeholder(tf.float32, shape=[None, input_size])
y_input = tf.placeholder(tf.float32, shape=[None, no_classes])
shape
인수의 None
은 아직 배치 크기를 정의하지 않았으므로 어떤 크기든 될 수 있음을 의미한다. x_input
의 두 번째 인수는 텐서의 크기이며 y_input
의 두 번쨰 인수는 클래스 수이다. 또한 tf.float32
를 통해 데이터를 실수 형태로 전달한다. 이제 퍼셉트론을 정의할 수 있다.
xxxxxxxxxx
weights = tf.Variable(tf.random_normal([input_size, no_classes]))
bias = tf.Variable(tf.random_normal([no_classes]))
텐서플로에서 Variable은 변하는 값을 지정해주는 것이다. 학습할 때 변하는 값은 학습 시키고 싶은 파라미터(뉴럴넷의 weight, CNN의 필터 등..)이다. 이 값은 초기화를 해주는데 0으로 할 수 있지만 일반적으로 가우시안 분포에서 뽑게된다.
가중치 변수인 weights
는 입력 크기의 shape 및 클래스 수의 정규 분포를 따르는 임의 값으로 초기화한다. 편향 변수 bias
또한 클래스 수와 동일한 크기의 정규 분포를 따르는 임의 값으로 초기화한다.
xxxxxxxxxx
logits = tf.matmul(x_input, weights) + bias
Y = W * X + bias
뉴렬을 계산하는 오퍼레이터를 정의해준다. x와 weight는 행렬의 형태로, 텐서플로의 matmul()
을 통해 곱셈 연산을 지정할 수 있다.
xxxxxxxxxx
softmax_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_input, logits=logits)
loss_operation = tf.reduce_mean(softmax_cross_entropy)
optimiser = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss_operation)
퍼셉트론으로 생성된 logit의 결과는 one-hot으로 표현된 y_input
과 비교되어야 한다. 이 때 교차 엔트로피와 소프트맥스를 사용한다. 텐서플로의 softmax_cross_entropy_with_logits
API가 이 작업을 수행한다.
모델과 훈련 작업의 정의가 완료되면 데이터로 모델을 훈련할 수 있따. 훈련 과정에서 기울기를 구하고 가중치 업데이트가 일어난다.
xxxxxxxxxx
session = tf.Session()
session.run(tf.global_variables_initializer()
위 코드처럼 세션을 시작하고 전역 변수 초기화 함수를 통해 변수를 초기화한다.
이제 배치의 데이터를 반복해서 읽고 모델을 훈련시킨다.
xxxxxxxxxx
for batch_no in range(total_batches):
mnist_batch = mnist_data.train.next_batch(batch_size)
train_images, train_labels = mnist_batch[0], mnist_batch[1]
_, loss_value = session.run([optimiser, loss_operation],
feed_dict={x_input: train_images, y_input: train_labels})
print("[%3d] loss_value : %lf" % (batch_no, loss_value))
run()
을 통해 세션을 실행함으로써 모델 훈련을 실행하고 그래프가 가중치를 업데이트할 수 있도록 최적화 알고리즘이 호출되어야 한다. feed_dic
은 플레이스홀더에 입력 레이블과 타깃 레이블을 직접적으로 입력하기 위해 사용된다.
loss_value의 출력을 보면 0단계에서는 두자리 숫자에서 시작하지만 배치 200번이 끝나면 0~1의 값으로 줄어드는 것을 확인할 수 있다.
predictions = tf.argmax(logits, 1)
correct_predictions = tf.equal(predictions, tf.argmax(y_input, 1))
accuracy_operation = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
test_images, test_labels = mnist_data.test.images, mnist_data.test.labels
accuracy_value = session.run(accuracy_operation,
feed_dict={x_input: test_images, y_input: test_labels})
print('Accuracy : ', accuracy_value)
session.close()
정확도를 계산해 모델이 얼마나 잘 동작하는지 확인할 수 있다. 동일하게 run()
을 실행하며 feed_dict
에는 테스트 데이터가 주어진다. 이를 통해 생성한 모델의 정확도를 확인할 수 있다.