티스토리 뷰
2. Neural Networks / L8. TensorFlow - Save and Restore TensorFlow Models
chrisysl 2018. 8. 28. 21:27Save and Restore TensorFlow Models
- 모델을 학습하는데는 몇시간이 걸릴 수 있다. 그 도중에 텐서플로우 세션을 닫아버리면
- 학습 된 모든 weights와 bias들을 잃게된다.
- 이렇게 잃어버리면 모델을 다시 학습시켜야한다.
- 다행히 이런 문제를 방지하기위해 텐서플로우는 tf.train.Saver라는 클래스를 제공한다.
- 이 클래스를 사용하면 진행상황을 저장할 수 있다.
- 다시말해 모든 tf.Variable을 file system에 저장하는 기능을 한다.
Saving Variables
- 먼저 weights와 bias 텐서를 저장하는 예시부터 진행해보자.
- 먼저 두 변수를 저장해보고, 나중엔 실제 모델에서 모든 가중치들을 저장해보도록 하자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | import tensorflow as tf # The file path to save the data save_file = './model.ckpt' # Two Tensor Variables: weights and bias weights = tf.Variable(tf.truncated_normal([2, 3])) bias = tf.Variable(tf.truncated_normal([3])) # Class used to save and/or restore Tensor Variables saver = tf.train.Saver() with tf.Session() as sess: # Initialize all the Variables sess.run(tf.global_variables_initializer()) # Show the values of weights and bias print('Weights:') print(sess.run(weights)) print('Bias:') print(sess.run(bias)) # Save the model saver.save(sess, save_file) | cs |
- 텐서의 가중치와 바이어스는 tf.truncated_normal() 함수를 사용하여 임의의 값으로 설정된다.
- 값은 tf.train.Saver.save() 함수를 사용하여 save_file 위치에 저장된다.
- .ckpt 확장자는 "체크포인트(check point)"를 의미한다.
- 텐서플로우 0.11.0RC1 또는 그 이상의 버전을 사용하는경우 메타파일도 같이 생성된다.
- 이 메타파일에는 텐서플로우 그래프가 포함되어있다.
Loading Variables
- 저장이 끝났다면 이 텐서변수들을 다시 모델로 불러들이는 작업을 해보도록 하자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | # Remove the previous weights and bias tf.reset_default_graph() # Two Variables: weights and bias weights = tf.Variable(tf.truncated_normal([2, 3])) bias = tf.Variable(tf.truncated_normal([3])) # Class used to save and/or restore Tensor Variables saver = tf.train.Saver() with tf.Session() as sess: # Load the weights and bias saver.restore(sess, save_file) # Show the values of weights and bias print('Weight:') print(sess.run(weights)) print('Bias:') print(sess.run(bias)) | cs |
- 보면 가중치와 바이어스 텐서를 생성해 두고, 거기에 불러와야함을 알 수 있다.
- tf.train.Saver.restore() 함수는 저장된 데이터를 가중치와 바이어스에 로드하는 역할만 하기 때문.
- tf.train.Saver.restore() 함수는 모든 텐서플로우 변수를 설정하기때문에
- tf.global_variables_initializer() 함수를 호출하여 초기화 할 필요가 없다.
>>
Save a Trained Model
- 이번엔 학습된 모델을 저장하는 방법에 대해 알아보자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | # Remove previous Tensors and Operations tf.reset_default_graph() from tensorflow.examples.tutorials.mnist import input_data import numpy as np learning_rate = 0.001 n_input = 784 # MNIST data input (img shape: 28*28) n_classes = 10 # MNIST total classes (0-9 digits) # Import MNIST data mnist = input_data.read_data_sets('.', one_hot=True) # Features and Labels features = tf.placeholder(tf.float32, [None, n_input]) labels = tf.placeholder(tf.float32, [None, n_classes]) # Weights & bias weights = tf.Variable(tf.random_normal([n_input, n_classes])) bias = tf.Variable(tf.random_normal([n_classes])) # Logits - xW + b logits = tf.add(tf.matmul(features, weights), bias) # Define loss and optimizer cost = tf.reduce_mean(\ tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)) optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\ .minimize(cost) # Calculate accuracy correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | cs |
- 그다음 모델을 학습시키고 학습시킨 가중치를 저장해보자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | import math save_file = './train_model.ckpt' batch_size = 128 n_epochs = 100 saver = tf.train.Saver() # Launch the graph with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # Training cycle for epoch in range(n_epochs): total_batch = math.ceil(mnist.train.num_examples / batch_size) # Loop over all batches for i in range(total_batch): batch_features, batch_labels = mnist.train.next_batch(batch_size) sess.run( optimizer, feed_dict={features: batch_features, labels: batch_labels}) # Print status for every 10 epochs if epoch % 10 == 0: valid_accuracy = sess.run( accuracy, feed_dict={ features: mnist.validation.images, labels: mnist.validation.labels}) print('Epoch {:<3} - Validation Accuracy: {}'.format( epoch, valid_accuracy)) # Save the model saver.save(sess, save_file) print('Trained Model Saved.') | cs |
Load a Trained Model
- 메모리로부터 학습된 모델의 가중치와 바이어스를 로드하려면 다음과같다.
1 2 3 4 5 6 7 8 9 10 11 | saver = tf.train.Saver() # Launch the graph with tf.Session() as sess: saver.restore(sess, save_file) test_accuracy = sess.run( accuracy, feed_dict={features: mnist.test.images, labels: mnist.test.labels}) print('Test Accuracy: {}'.format(test_accuracy)) | cs |
>>
'Deep Learning' 카테고리의 다른 글
2. Neural Networks / L8. TensorFlow - TensorFlow Dropout (0) | 2018.08.30 |
---|---|
2. Neural Networks / L8. TensorFlow - Finetuning (0) | 2018.08.28 |
2. Neural Networks / L8. TensorFlow - Deep Neural Network in TensorFlow (0) | 2018.08.27 |
2. Neural Networks / L8. TensorFlow - Multilayer Neural Networks (0) | 2018.08.23 |
2. Neural Networks / L8. TensorFlow - Lab: NotMNIST in TensorFlow (0) | 2018.08.21 |