PyTorch Basics

Introduction

PyTorch is an open-source library, based on the Torch library, for machine learning and deep learning. This notebook covers the following five tensor functions:

  • torch.eye
  • torch.squeeze
  • torch.unsqueeze
  • torch.where
  • torch.cat
# Import torch and other required modules
import torch

Function 1 - eye

Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.

a = torch.eye(5)
a
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

Here, torch.eye(5) creates a 5 by 5 matrix whose diagonal elements are 1 and whose remaining elements are 0.
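
The square case above is the most common use, but torch.eye also accepts a second size argument and a dtype keyword. The sketch below is an illustrative addition, not part of the original notebook; the variable names (rect, int_eye, m, unchanged) are made up for the example.

```python
import torch

# Rectangular case: 3 rows, 4 columns.
# Ones appear only where the row index equals the column index.
rect = torch.eye(3, 4)

# Integer identity matrix, requested through the dtype keyword.
int_eye = torch.eye(3, dtype=torch.int64)

# Multiplying a square matrix by the identity leaves it unchanged.
m = torch.arange(9, dtype=torch.float32).reshape(3, 3)
unchanged = torch.equal(m @ torch.eye(3), m)

print(rect)
print(int_eye)
print(unchanged)
```

The last check is the usual reason to build an identity matrix: it acts as the neutral element for matrix multiplication.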