# PyTorch: why use optimizer.zero_grad()

optimizer.zero_grad() sets the gradients to zero, that is, it resets the derivative of the loss with respect to each weight to 0.

When learning PyTorch, I noticed that for each batch, the training loop mostly looks like this:

```
# zero the parameter gradients
optimizer.zero_grad()

# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
```

I understand these operations as gradient descent. For comparison, here is a simple gradient descent implementation I wrote earlier:

```
# gradient descent
# (dot, input, label, m and n are assumed to be defined elsewhere:
#  dot is an inner product, m is the number of samples, n the number of features)
weights = [0] * n
alpha = 0.0001
max_iter = 50000
for i in range(max_iter):
    loss = 0
    d_weights = [0] * n
    for k in range(m):
        h = dot(input[k], weights)  # forward pass: predicted value
        # accumulate this sample's gradient into d_weights
        d_weights = [d_weights[j] + (label[k] - h) * input[k][j] for j in range(n)]
        loss += (label[k] - h) * (label[k] - h) / 2
    d_weights = [d_weights[k] / m for k in range(n)]
    weights = [weights[k] + alpha * d_weights[k] for k in range(n)]  # update step
    if i % 10000 == 0:
        print("Iteration %d loss: %f" % (i, loss / m))
print(weights)
```
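As a sanity check, the loop above can be run end to end on a tiny synthetic dataset. The data below (points labeled by the true weights [2, 3]) is purely hypothetical, just to show that the sketch converges:

```python
# Batch gradient descent demo: recover the true weights [2, 3]
def dot(x, w):
    return sum(xj * wj for xj, wj in zip(x, w))

inputs = [[1.0, 2.0], [2.0, 1.0], [3.0, 3.0], [1.0, 0.0]]
labels = [2 * x[0] + 3 * x[1] for x in inputs]  # true weights: [2, 3]
m, n = len(inputs), len(inputs[0])

weights = [0.0] * n
alpha = 0.05
for i in range(2000):
    d_weights = [0.0] * n                       # "zero the gradients"
    for k in range(m):
        h = dot(inputs[k], weights)             # forward
        d_weights = [d_weights[j] + (labels[k] - h) * inputs[k][j]
                     for j in range(n)]         # accumulate gradient ("backward")
    d_weights = [d / m for d in d_weights]
    weights = [w + alpha * d for w, d in zip(weights, d_weights)]  # "step"

print([round(w, 2) for w in weights])  # → [2.0, 3.0]
```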

Comparing the two, the operations correspond one to one:

optimizer.zero_grad() corresponds to d_weights = [0] * n

That is, it initializes the gradients to zero. This is necessary because the gradient of a batch's loss with respect to a weight is the sum of the per-sample gradients, accumulated into the same buffer; without resetting, gradients from the previous batch would leak into the current one.
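The accumulation behavior is easy to see in the manual version. The toy helper below (plain Python, not PyTorch itself) accumulates per-sample gradients into a buffer; if that buffer is not reset between batches, the second batch's gradient is contaminated:

```python
def grad(batch_x, batch_y, weights, d_weights):
    # Accumulate the gradient of each sample into d_weights (sum over the batch).
    for x, y in zip(batch_x, batch_y):
        h = sum(xj * wj for xj, wj in zip(x, weights))
        for j in range(len(weights)):
            d_weights[j] += (y - h) * x[j]
    return d_weights

w = [0.0, 0.0]
batch1 = ([[1.0, 0.0]], [1.0])
batch2 = ([[0.0, 1.0]], [1.0])

# Correct: zero the buffer before each batch (what optimizer.zero_grad() does).
g1 = grad(*batch1, w, [0.0, 0.0])
g2 = grad(*batch2, w, [0.0, 0.0])

# Buggy: forget to zero — batch 1's gradient contaminates batch 2's.
g_stale = grad(*batch1, w, [0.0, 0.0])
g_bad = grad(*batch2, w, g_stale)

print(g2)     # gradient of batch 2 alone: [0.0, 1.0]
print(g_bad)  # batch 2 plus leftover from batch 1: [1.0, 1.0]
```

This is exactly why the PyTorch loop calls optimizer.zero_grad() at the top of every iteration: loss.backward() adds into the existing .grad buffers rather than overwriting them.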

outputs = net(inputs) corresponds to h = dot(input[k], weights)

That is, the forward pass computes the predicted values.

loss = criterion(outputs, labels) corresponds to loss += (label[k] - h) * (label[k] - h) / 2

This step computes the loss. (Strictly speaking, in the manual version this step is optional: the loss value itself is not used when computing the gradient; it only tells us how large the current loss is.)

loss.backward() corresponds to d_weights = [d_weights[j] + (label[k] - h) * input[k][j] for j in range(n)]

That is, back propagation computes the gradients.

optimizer.step() corresponds to weights = [weights[k] + alpha * d_weights[k] for k in range(n)]

All parameters are updated
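One sign convention is worth noting: the manual code accumulates (label[k] - h) * input[k][j], which is the *negative* gradient of the squared loss, and then adds it; PyTorch's plain SGD instead stores the gradient itself in each parameter's .grad and subtracts lr * grad in step(). The class below is a minimal sketch of those semantics in plain Python (it is not the real torch.optim.SGD, and the parameter representation is invented for illustration):

```python
class TinySGD:
    """Sketch of SGD semantics: step() does p -= lr * grad, zero_grad() clears grads."""

    def __init__(self, params, lr):
        self.params = params  # list of dicts: {"value": float, "grad": float}
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p["grad"] = 0.0

    def step(self):
        for p in self.params:
            p["value"] -= self.lr * p["grad"]

# Usage: minimize f(w) = (w - 5)^2, whose gradient is 2 * (w - 5).
w = {"value": 0.0, "grad": 0.0}
opt = TinySGD([w], lr=0.1)
for _ in range(100):
    opt.zero_grad()                    # reset the gradient buffer
    w["grad"] = 2 * (w["value"] - 5)   # "backward": fill in the gradient
    opt.step()                         # "step": move against the gradient
print(round(w["value"], 4))            # → 5.0
```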

If there are any mistakes, please point them out. Comments and discussion are welcome.