1. 체크포인트 파일 저장하기
최적화가 끝난 후 학습된 변수들을 체크포인트 파일에 저장한다.
save.save(sess, './model/myNetwork.ckpt', global_step = total_step)
두 번째 인수는 체크포인트 파일이 저장될 위치와 체크포인트 파일의 이름을 의미하며, 체크포인트가 저장될 디렉토리(이 경우는 './model')는 미리 만들어져 있어야 한다.
세 번째 인수의 global_step의 값은 텐서 변수 또는 숫자 값을 넣을 수 있으며, 저장할 파일의 이름에 추가로 붙게되어 여러 상태의 체크포인트를 만들 수 있고, 가장 효과적인 모델을 선별하여 사용할 수 있다.
2. 체크포인트 파일 불러오기
./model이라는 디렉토리에 기존에 학습해둔 파일이 있는지 확인한다.
ckpt = tf.train.get_checkpoint_state('./model')
모델이 있다면 saver.restore() 함수를 이용하여 학습된 값들을 불러오고, 아니면 변수를 새로 초기화한다.
if ckpt and tf.train.checkpoint_exist(ckpt.model_checkpoint_path):
saver.restore(sess, ckpt.model_checkpoint_path)
else:
sess.run(tf.global_variables_initializer())
3. 주의사항(?)
나만 멍청한 짓을 했나보다. 책을 보고 따라하다보니 체크포인트 파일을 저장하는데는 문제가 없었다. 일단 학습을 끝내고 체크포인트 파일에 저장하니 해당 폴더에 체크포인트 파일이 저장된 것을 확인할 수 있었다. 아래와 같이 첫 번째 실행은 정상적으로 잘 동작하였다.
runfile('C:/Users/Thriving_Zinnias/AnacondaProjects/spyder-py3/golbin_checkpoint.py', wdir='C:/Users/Thriving_Zinnias/AnacondaProjects/spyder-py3')
Step: 1 Cost: 0.735
Step: 2 Cost: 0.708
Prediction Value : [0 1 2 0 0 2]
Target Value : [0 1 2 0 0 2]
Accuracy : 100.00
이제 저장된 체크포인트 파일을 불러와서 모델을 실행시켜보고자 스크립를 'run'하였다.
스크립트의 두 번째 실행이었다. 아래와 같은 Error가 반겨주었다.
INFO:tensorflow:Restoring parameters from ./model\dnn.ckpt-2
Traceback (most recent call last):
File "<ipython-input-4-b4971a68fef7>", line 1, in <module>
runfile('C:/Users/Thriving_Zinnias/AnacondaProjects/spyder-py3/golbin_checkpoint.py', wdir='C:/Users/Thriving_Zinnias/AnacondaProjects/spyder-py3')
<중략>
File "C:\Users\Thriving_Zinnias\Anaconda3\envs\TF-35\lib\site-packages\tensorflow\python\client\session.py", line 1152, in _do_callraise type(e)(node_def, op, message)
NotFoundError: Key global_step_2 not found in checkpoint
[[Node: save_2/RestoreV2_35 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save_2/Const_0_0, save_2/RestoreV2_35/tensor_names, save_2/RestoreV2_35/shape_and_slices)]]
Caused by op 'save_2/RestoreV2_35', defined at:
File "C:\Users\Thriving_Zinnias\Anaconda3\envs\TF-35\lib\site-packages\spyder\utils\ipython\start_kernel.py", line 241, in <module>
main()
<중략>
File "C:\Users\Thriving_Zinnias\Anaconda3\envs\TF-35\lib\site-packages\tensorflow\python\framework\ops.py", line 1269, in __init__self._traceback = _extract_stack()
NotFoundError (see above for traceback): Key global_step_2 not found in checkpoint
[[Node: save_2/RestoreV2_35 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save_2/Const_0_0, save_2/RestoreV2_35/tensor_names, save_2/RestoreV2_35/shape_and_slices)]]
이것때문에 몇시간을 허비하였다. 문제는 스크립트를 실행시키면 해당 내용이 커널에 계속 남아 있는다는 것이다. 이게 문제일거라고는 생각도 못했다.
커널을 재시작하고, 스크립트를 'run'하였더니 아래와 같이 잘 되었다. 모델을 불러와서 정상적으로 동작하면 화면에 나타나게 했던 'Trained model has been loaded.'가 보인다.
runfile('C:/Users/Thriving_Zinnias/AnacondaProjects/spyder-py3/golbin_checkpoint.py', wdir='C:/Users/Thriving_Zinnias/AnacondaProjects/spyder-py3')
INFO:tensorflow:Restoring parameters from ./model\dnn.ckpt-2
Trained model has been loaded.
Step: 3 Cost: 0.704
Step: 4 Cost: 0.680
Prediction Value : [0 1 2 0 0 2]
Target Value : [0 1 2 0 0 2]
Accuracy : 100.00
댓글 없음:
댓글 쓰기