
Five gradient/derivative-related PyTorch functions

Below are five PyTorch functions that deal primarily with gradient calculation and manipulation.

PyTorch is an open-source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing. It is primarily developed by Facebook's AI Research lab (FAIR) and is released as free, open-source software under the Modified BSD license. Although the Python interface is more polished and the primary focus of development, PyTorch also has a C++ interface.

Following are the five functions I will be discussing:

  • detach()
  • no_grad()
  • clone()
  • backward()
  • register_hook()
# Import torch and other required modules
import torch

1. tensor.detach()

tensor.detach() creates a new tensor that shares storage with the original tensor but does not require gradients. You should use detach() when you want to remove a tensor from a computation graph.
To enable automatic differentiation, PyTorch keeps track of all operations involving tensors for which the gradient may need to be computed (i.e., requires_grad is True). The operations are recorded as a directed graph. The detach() method constructs a new view on a tensor that is declared not to need gradients, i.e., it is excluded from further tracking of operations, and therefore the subgraph involving this view is not recorded.
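
As a quick illustration of this (a minimal sketch; the variable names are purely for this example), a detached tensor reports requires_grad=False and carries no grad_fn, confirming that it sits outside the recorded graph:

# A detached view is excluded from gradient tracking (illustrative names)
a = torch.ones(3, requires_grad=True)
b = a.detach()

print(b.requires_grad)   # False -- the detached view does not require gradients
print(b.grad_fn)         # None  -- no operation is recorded for it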

# Situation when we are not using detach()
x = torch.ones(10, requires_grad=True)

y = x**2
z = x**3

r = (y + z).sum()
r.backward()
# d/dx (x^2 + x^3) = 2x + 3x^2, which is 5 at x = 1
print(x.grad)
tensor([5., 5., 5., 5., 5., 5., 5., 5., 5., 5.])
# Situation when we are using detach()
x = torch.ones(10, requires_grad=True)

y = x**2
z = x.detach()**3

r = (y + z).sum()
r.backward()

# z is built from the detached view, so only y = x^2 contributes: d/dx (x^2) = 2x, which is 2 at x = 1
print(x.grad)
tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
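
Note that, because the detached tensor shares storage with the original, an in-place modification of one is visible in the other. Below is a minimal sketch of this behaviour (the values and variable names are illustrative):

# detach() returns a view on the same storage, so in-place edits are shared
x = torch.ones(5, requires_grad=True)
d = x.detach()

d[0] = 7.0      # modify the detached view in place
print(x)        # tensor([7., 1., 1., 1., 1.], requires_grad=True) -- x sees the change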