The trained model parameters need to be saved for later validation or testing. In TensorFlow, the tf.train.Saver() class provides model saving.
To save a model, first create a Saver object, for example:
saver=tf.train.Saver()
When creating the Saver object, one commonly used parameter is max_to_keep, which sets how many saved models to keep. The default is 5 (max_to_keep=5), meaning only the 5 most recent models are kept. If you want to keep the model from every epoch, set max_to_keep to None or 0, for example:
saver=tf.train.Saver(max_to_keep=0)
However, apart from taking up more disk space, this is of little practical use, so it is not recommended.
If you only want to keep the most recent model, set max_to_keep to 1:
saver=tf.train.Saver(max_to_keep=1)
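The retention rule described above can be modelled in a few lines of plain Python. This is only a sketch of the bookkeeping, not TensorFlow's implementation; the function name retained_checkpoints and the step values are hypothetical and chosen for illustration.

```python
from collections import deque


def retained_checkpoints(saved_steps, max_to_keep=5):
    """Return the steps whose checkpoints would still be kept.

    Hypothetical helper that models the max_to_keep rule only;
    it does not call TensorFlow.
    """
    if not max_to_keep:  # None or 0: nothing is ever deleted
        return list(saved_steps)
    # A bounded deque drops the oldest entry once it is full.
    return list(deque(saved_steps, maxlen=max_to_keep))


steps = list(range(100, 1100, 100))  # checkpoints at steps 100, 200, ..., 1000
print(retained_checkpoints(steps))                      # default: newest five
print(retained_checkpoints(steps, max_to_keep=1))       # newest one only
print(len(retained_checkpoints(steps, max_to_keep=0)))  # all ten are kept
```

With the default setting only the five newest steps survive, which is why max_to_keep=1 is enough when only the final model matters.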
After creating the saver object, you can save the trained model, such as:
saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
The first argument, sess, is the current session. The second argument sets the save path and file name. The third argument, global_step, appends the training step count to the model name as a suffix:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
…
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
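The naming pattern above can be reproduced without TensorFlow to see which names a training loop would generate. The helper checkpoint_prefix and the values of train_steps and save_every are assumptions made for this sketch; it only mirrors the "path-step" pattern and writes nothing to disk.

```python
def checkpoint_prefix(save_path, global_step=None):
    """Hypothetical helper mirroring the naming pattern shown above."""
    if global_step is None:
        # Without global_step the path is used unchanged.
        return save_path
    return "{}-{}".format(save_path, global_step)


train_steps = 3000  # assumed total number of training steps
save_every = 1000   # assumed interval between saves

prefixes = [
    checkpoint_prefix("my-model", step)
    for step in range(1, train_steps + 1)
    if step % save_every == 0
]
print(prefixes)  # ['my-model-1000', 'my-model-2000', 'my-model-3000']
```

Because the step number is part of the name, each save produces a distinct model instead of overwriting the previous one.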
2. Examples
import tensorflow as tf  # TensorFlow 1.x API
import numpy as np

x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4  # target function to fit

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False  # True: train and save checkpoints; False: restore the latest one
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        for i in range(train_steps):
            sess.run(train, feed_dict={x: x_data})
            # Save a checkpoint every checkpoint_steps iterations.
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
    else:
        # Look up the latest checkpoint recorded in checkpoint_dir.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass  # no checkpoint found; w and b keep their initial values
        print(sess.run(w))
        print(sess.run(b))
3. Recovery
Use the saver.restore() method to restore the variables:
saver.restore(sess, ckpt.model_checkpoint_path)
sess: the current session; the previously saved values are loaded into this session.
ckpt.model_checkpoint_path: the location of the saved model. You do not need to supply the model name yourself: tf.train.get_checkpoint_state() reads the checkpoint file to find out which model is the latest and what it is called.
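To make the lookup less mysterious, here is a rough sketch of what reading the checkpoint file amounts to. It assumes the file has its usual text layout, with one model_checkpoint_path line followed by all_model_checkpoint_paths lines; the sample contents and the helper latest_checkpoint_from_text are hypothetical. In real code, tf.train.get_checkpoint_state() does this for you.

```python
import re

# Sample contents of a "checkpoint" file, assumed for illustration only.
SAMPLE_CHECKPOINT_FILE = '''model_checkpoint_path: "model.ckpt-100"
all_model_checkpoint_paths: "model.ckpt-50"
all_model_checkpoint_paths: "model.ckpt-100"
'''


def latest_checkpoint_from_text(text):
    """Hypothetical helper: return the path recorded as the latest model."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', text, re.MULTILINE)
    return match.group(1) if match else None


print(latest_checkpoint_from_text(SAMPLE_CHECKPOINT_FILE))  # model.ckpt-100
print(latest_checkpoint_from_text(""))                      # None
```

The None case corresponds to the "if ckpt and ckpt.model_checkpoint_path" guard in the example: when nothing has been saved yet, there is no path to restore from.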
Reprinted from:
【1】 https://www.cnblogs.com/denny402/p/6940134.html
【2】 https://blog.csdn.net/u011500062/article/details/51728830
【3】 https://www.cnblogs.com/chamie/p/8780508.html