How to write a fit function for a machine learning model?

@PrajwalPrashanth I don't understand how the fit function trains our model.
In scikit-learn we use model.fit() and it does the trick, but here I don't know how it works.
In my assignment I used a loop to test different configurations, but I'm not quite sure which version the model takes as the final one.


@cnarvaa

Fit

  1. Takes a batch of input data and the corresponding labels (the ground-truth values from the dataloader)
  2. Gets predictions by passing the input data through the model
  3. Compares the predictions with the labels (this is the loss calculation)
  4. Calculates the gradients of the loss with respect to the weights
  5. Updates the weights based on the calculated gradients and the learning rate
  6. Resets the gradients for the next set of calculations

Repeat steps 1-6 for every batch in the dataloader, and repeat the whole pass over the data for the chosen number of epochs.
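The six steps above can be sketched without any framework. This is a minimal illustration, assuming a one-feature linear model `y = w * x + b`, mean squared error as the loss, and plain mini-batch gradient descent. The names `fit`, `batch_size` and `seed` are hypothetical (not from any library), and the gradients are derived by hand instead of by autograd as PyTorch would do.

```python
import random


def fit(data, epochs, lr, batch_size=4, seed=0):
    """Hypothetical helper: fit y = w * x + b to (x, y) pairs with mini-batch
    gradient descent on mean squared error. Returns (w, b, per-epoch mean loss)."""
    rng = random.Random(seed)
    data = list(data)  # copy so the caller's list is not reordered
    w, b = 0.0, 0.0    # model weights, starting from zero
    losses = []
    for _ in range(epochs):
        rng.shuffle(data)  # new batch order each epoch, like a shuffling dataloader
        epoch_losses = []
        for i in range(0, len(data), batch_size):
            # 1. Take a batch of inputs and their labels
            batch = data[i:i + batch_size]
            xs = [x for x, _ in batch]
            ys = [y for _, y in batch]
            n = len(batch)
            # 2. Get predictions by passing the inputs through the model
            preds = [w * x + b for x in xs]
            # 3. Compare predictions with labels: mean squared error loss
            errors = [p - y for p, y in zip(preds, ys)]
            epoch_losses.append(sum(e * e for e in errors) / n)
            # 4. Calculate gradients of the loss w.r.t. w and b (derived by hand)
            grad_w = sum(2 * e * x for e, x in zip(errors, xs)) / n
            grad_b = sum(2 * e for e in errors) / n
            # 5. Update the weights: step against the gradient, scaled by lr
            w -= lr * grad_w
            b -= lr * grad_b
            # 6. Reset gradients: grad_w / grad_b are recomputed from scratch
            #    every batch, so nothing accumulates here. In PyTorch, .grad
            #    accumulates, which is why training loops call optimizer.zero_grad().
        losses.append(sum(epoch_losses) / len(epoch_losses))
    return w, b, losses


data = [(x, 2 * x + 1) for x in range(-5, 6)]  # noise-free samples of y = 2x + 1
w, b, losses = fit(data, epochs=500, lr=0.01)
print(w, b)                   # should end up very close to 2 and 1
print(losses[0], losses[-1])  # the loss shrinks over the epochs
```

Note that this sketch simply returns whatever weights exist after the last update of the last epoch; nothing is "picked" automatically. If you loop over several configurations, each call produces its own trained weights, and keeping the best one is up to you (for example, by comparing a validation loss).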

This is the gist of the fit function. If you can elaborate on your question and be more specific about your doubt, I can explain it in more detail.

Courtesy: @PrajwalPrashanth