티스토리 뷰
Loading the Weights and Biases into a New Model
- 이미 학습하여 저장한 모델에 대해 미세한 조정(finetuning)이 가능하다.
- 하지만 저장된 변수를 직접적으로 수정된 모델에 대입하는것은 에러를 유발할 수 있다.
- 이런 문제점을 피하며 finetuning 하는법에대해 알아보자.
Naming Error
- 텐서플로우는 텐서와 연산(Tensors and Operations)에 있어 스트링 식별자로 name이라는 키워드를 사용한다.
- 만약 name이 주어지지 않으면 자동적으로 이를 생성한다.
- 텐서플로우는 첫번째 노드 이름을 <Type>으로 지정하고
- 바로 뒤따르는 두번째 노드로 이름을 <Type>_<number>로 지정한다.
- 이렇게 name을 할당하는것이 각기 다른순서의 가중치와 바이어스에 어떻게 영향을 미치는지 확인해보자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | import tensorflow as tf # Remove the previous weights and bias tf.reset_default_graph() save_file = 'model.ckpt' # Two Tensor Variables: weights and bias weights = tf.Variable(tf.truncated_normal([2, 3])) bias = tf.Variable(tf.truncated_normal([3])) saver = tf.train.Saver() # Print the name of Weights and Bias print('Save Weights: {}'.format(weights.name)) print('Save Bias: {}'.format(bias.name)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, save_file) # Remove the previous weights and bias tf.reset_default_graph() # Two Variables: weights and bias bias = tf.Variable(tf.truncated_normal([3])) weights = tf.Variable(tf.truncated_normal([2, 3])) saver = tf.train.Saver() # Print the name of Weights and Bias print('Load Weights: {}'.format(weights.name)) print('Load Bias: {}'.format(bias.name)) with tf.Session() as sess: # Load the weights and bias - ERROR saver.restore(sess, save_file) | cs |
>>
- 위와같이 결과가 나오게 되는데, 가중치와 바이어스의 name property가 모델을 저장했을때와 달라졌음을 확인할 수 있다.
- 그렇기 때문에 "Assign to both tensors of shapes" 오류를 리턴하는 것이다.
- saver.restore(sess, save_file)에서 weight 데이터를 bias에, bias 데이터를 weight에 대입하고있기 때문.
- 따라서 텐서플로우가 알아서 name property를 정하게 하지말고, 직접 할당해보자.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 | import tensorflow as tf tf.reset_default_graph() save_file = 'model.ckpt' # Two Tensor Variables: weights and bias weights = tf.Variable(tf.truncated_normal([2, 3]), name='weights_0') bias = tf.Variable(tf.truncated_normal([3]), name='bias_0') saver = tf.train.Saver() # Print the name of Weights and Bias print('Save Weights: {}'.format(weights.name)) print('Save Bias: {}'.format(bias.name)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, save_file) # Remove the previous weights and bias tf.reset_default_graph() # Two Variables: weights and bias bias = tf.Variable(tf.truncated_normal([3]), name='bias_0') weights = tf.Variable(tf.truncated_normal([2, 3]) ,name='weights_0') saver = tf.train.Saver() # Print the name of Weights and Bias print('Load Weights: {}'.format(weights.name)) print('Load Bias: {}'.format(bias.name)) with tf.Session() as sess: # Load the weights and bias - No Error saver.restore(sess, save_file) print('Loaded Weights and Bias successfully.') | cs |
- 위와같이 name property를 설정해주면 텐서가 자동으로 할당하여 에러가 나는 경우를 방지할 수 있다.
>>