Saving and restoring TensorFlow models with tf.train.Saver()

1. Saving

Trained model parameters are saved so they can be reloaded later for validation or testing. In TensorFlow, the tf.train.Saver() class handles saving and restoring models.

To save a model, first create a Saver object:

saver=tf.train.Saver()
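When no var_list argument is given, the Saver collects the variables that exist at the moment it is constructed. As a result, it should be created after the model's variables are defined. The minimal sketch below assumes TensorFlow 2.x running the 1.x graph API through tf.compat.v1. The variable names v1 and v2 are hypothetical. It creates v2 after the Saver and then lists the contents of the checkpoint with tf.train.list_variables.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x, using the 1.x graph API via tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    v1 = tf1.get_variable("v1", shape=[], initializer=tf1.zeros_initializer())
    # The Saver is built here, so it only knows about v1.
    saver = tf1.train.Saver()
    # v2 is created after the Saver, so it is NOT written to the checkpoint.
    v2 = tf1.get_variable("v2", shape=[], initializer=tf1.ones_initializer())

save_dir = tempfile.mkdtemp()
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    path = saver.save(sess, os.path.join(save_dir, "model.ckpt"))

# Names of the variables actually stored in the checkpoint.
saved_names = [name for name, _ in tf1.train.list_variables(path)]
print(saved_names)  # ['v1']
```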

The argument most often set when creating the Saver is max_to_keep, which controls how many recent checkpoints are kept. The default is 5 (max_to_keep=5). Only the five most recent checkpoints are kept, and older ones are deleted as new ones are saved. To keep every checkpoint, for example one saved after each epoch, set max_to_keep to None or 0:

saver=tf.train.Saver(max_to_keep=0)

However, this uses more disk space and is rarely useful in practice, so it is not recommended.

If you only want to keep the most recent checkpoint, set max_to_keep to 1:

saver=tf.train.Saver(max_to_keep=1)
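The following sketch shows the max_to_keep retention behaviour. It assumes TensorFlow 2.x with tf.compat.v1, and the variable name v is hypothetical. The code saves four checkpoints with max_to_keep=2 and then inspects Saver.last_checkpoints and the .index files left on disk.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x, using the 1.x graph API via tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()

save_dir = tempfile.mkdtemp()
prefix = os.path.join(save_dir, "model.ckpt")

graph = tf.Graph()
with graph.as_default():
    v = tf1.get_variable("v", shape=[], initializer=tf1.zeros_initializer())
    # Keep only the 2 most recent checkpoints; older ones are deleted on save.
    saver = tf1.train.Saver(max_to_keep=2)

with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    for step in range(1, 5):  # writes model.ckpt-1 ... model.ckpt-4
        saver.save(sess, prefix, global_step=step)

# Checkpoints the Saver still tracks, oldest first.
kept = [os.path.basename(p) for p in saver.last_checkpoints]
# .index files actually remaining on disk.
index_files = sorted(f for f in os.listdir(save_dir) if f.endswith(".index"))
print(kept)  # ['model.ckpt-3', 'model.ckpt-4']
print(index_files)
```

Setting max_to_keep=1 in the same script would leave only model.ckpt-4.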

After creating the Saver, call save() to write the trained model, for example:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

The first argument is the session whose variable values are saved. The second sets the save path and file name prefix. The optional global_step argument appends the training step number to the file name:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'

saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
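To see what save() writes, the sketch below saves with global_step=1000 and lists the directory. It assumes TensorFlow 2.x with tf.compat.v1 and a hypothetical variable w. The name my-model-1000 is a prefix, not a single file. In the default V2 checkpoint format it is accompanied by .index and .data-* files, a .meta graph file, and a text file named checkpoint that records the latest checkpoint.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x, using the 1.x graph API via tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    w = tf1.get_variable("w", shape=[1], initializer=tf1.zeros_initializer())
    saver = tf1.train.Saver()

save_dir = tempfile.mkdtemp()
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    # global_step is appended to the prefix: my-model -> my-model-1000
    path = saver.save(sess, os.path.join(save_dir, "my-model"), global_step=1000)

print(os.path.basename(path))  # my-model-1000
for name in sorted(os.listdir(save_dir)):
    print(name)
```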

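2. Example

The script below fits the line y = 4x + 4. Run it with isTrain = True to train and write a checkpoint every 50 steps. Then run it with isTrain = False to restore the latest checkpoint and print the learned w and b.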

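# Written for the TensorFlow 1.x API (tf.placeholder, tf.Session, ...).
# Under TensorFlow 2.x, use tf.compat.v1 and call tf.compat.v1.disable_eager_execution().
import os

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4  # target function the model should learn
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False        # True: train and save; False: restore and print
train_steps = 100
checkpoint_steps = 50  # save a checkpoint every 50 steps
checkpoint_dir = 'ckpt/'
saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)  # Saver.save() needs the directory to exist
        for i in range(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'), global_step=i + 1)
    else:
        # Read the 'checkpoint' file in checkpoint_dir to find the latest checkpoint.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint found in', checkpoint_dir)
        print(sess.run(w))
        print(sess.run(b))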

3. Recovery

Use the saver.restore() method to restore variables:

saver.restore(sess, ckpt.model_checkpoint_path)

sess: the current session. The previously saved variable values are loaded into this session, so the variables do not need to be initialized first.

ckpt.model_checkpoint_path: the path of the checkpoint to restore. You do not need to hard-code the model file name. Here ckpt is the result of tf.train.get_checkpoint_state(checkpoint_dir). That function reads the checkpoint file in the directory, which records the most recent checkpoint and its name.
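The sketch below shows the full restore path. It assumes TensorFlow 2.x with tf.compat.v1, and build_graph is a hypothetical helper. The script saves two checkpoints with different values of b. It then builds a fresh graph and session, asks get_checkpoint_state() which checkpoint is latest, and restores it without running any initializer.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x, using the 1.x graph API via tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def build_graph():
    """Hypothetical helper: a graph with one variable b and an op to assign it."""
    graph = tf.Graph()
    with graph.as_default():
        b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
        new_value = tf1.placeholder(tf.float32, shape=[1])
        assign_b = tf1.assign(b, new_value)
        saver = tf1.train.Saver()
    return graph, b, new_value, assign_b, saver


save_dir = tempfile.mkdtemp()

# 1) Save two checkpoints: b = [1.] at step 1, b = [2.] at step 2.
graph, b, new_value, assign_b, saver = build_graph()
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    for step, value in [(1, 1.0), (2, 2.0)]:
        sess.run(assign_b, feed_dict={new_value: [value]})
        saver.save(sess, os.path.join(save_dir, "model.ckpt"), global_step=step)

# 2) Fresh graph and session: the 'checkpoint' file says which one is latest.
graph, b, _, _, saver = build_graph()
latest_path, restored = None, None
with tf1.Session(graph=graph) as sess:
    ckpt = tf1.train.get_checkpoint_state(save_dir)
    if ckpt and ckpt.model_checkpoint_path:
        latest_path = ckpt.model_checkpoint_path
        # restore() loads b from the file, so no initializer is run here.
        saver.restore(sess, latest_path)
        restored = sess.run(b)

print(os.path.basename(latest_path), restored)
```

tf.train.latest_checkpoint(save_dir) is a shorter way to get the same path when the rest of the CheckpointState is not needed.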

Reprinted from:

[1] https://www.cnblogs.com/denny402/p/6940134.html

[2] https://blog.csdn.net/u011500062/article/details/51728830

[3] https://www.cnblogs.com/chamie/p/8780508.html
