Short Post: Keras progress bars in PyTorch code
I really like the progress bars in the Keras API, so here is how to add them to any Python code (as long as TensorFlow is installed).
Here is an example with a PyTorch training loop:
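```python
import tensorflow as tf

# Assumes `dataloader` (e.g. a torch.utils.data.DataLoader), `train_step(batch)`
# and `validate()` are defined elsewhere and return plain floats (e.g. loss.item()).
n_epochs = 3
for epoch in range(n_epochs):
    n_batches = len(dataloader)
    print(f'Epoch {epoch+1}/{n_epochs}')
    # One progress bar per epoch, with one step per batch.
    pbar = tf.keras.utils.Progbar(target=n_batches)
    for idx, batch in enumerate(dataloader):
        train_loss = train_step(batch)
        # idx is 0-based, so the bar only reaches n_batches - 1 inside this loop.
        pbar.update(idx, values=[("loss", train_loss)])
    val_loss = validate()
    # A final update at the target finishes the bar and adds the validation loss.
    pbar.update(n_batches, values=[('val_loss', val_loss)])
```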
Output:
Epoch 1/3
10/10 [==============================] - 2s 225ms/step - loss: 0.4698 - val_loss: 0.9111
Epoch 2/3
10/10 [==============================] - 2s 226ms/step - loss: 0.4626 - val_loss: 0.8433
Epoch 3/3
10/10 [==============================] - 2s 225ms/step - loss: 0.5835 - val_loss: 0.4997
Notes:
- The metrics shown next to the progress bar are averaged over the updates so far. So if the per-batch loss values over an epoch were [5, 4, 3, 2, 1], you would see [5, 4.5, 4, 3.5, 3] as the progress bar fills up. If you need the raw values, pass the `stateful_metrics` argument when you create the progress bar, e.g. `pbar = tf.keras.utils.Progbar(target=n_batches, stateful_metrics=['loss'])`. This makes the progress bar display the value for the key `loss` as-is at each update instead of averaging it.
- It's important to "finish" the progress bar: the final `update` call must pass the target (the total number of steps set when the progress bar was created) as its first argument. In the example above, the target was the number of batches, `n_batches`. So if you were skipping the validation step, you would still need to call `pbar.update(n_batches, values=None)` after all the training batches were processed, to finish the progress bar. In that case it would only show the training loss.
Failing to do so will leave an incomplete progress bar for each epoch, like this:
Epoch 1/3
9/10 [==========================>...] - ETA: 0s - loss: 0.4153
Epoch 2/3
9/10 [==========================>...] - ETA: 0s - loss: 0.4637
Epoch 3/3
9/10 [==========================>...] - ETA: 0s - loss: 0.5152
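To make the averaging in the first note concrete, here is a minimal pure-Python sketch (not Keras code) of what the bar would display for a metric after each update. It assumes each update advances the bar by exactly one step, as in the training loop above; `displayed_values` is a hypothetical helper, and its `stateful` flag mimics listing the metric in `stateful_metrics`.

```python
def displayed_values(batch_values, stateful=False):
    """Return the value a Keras-style bar would show after each update.

    Sketch assumption: one step per update, so a non-stateful metric is
    shown as the running mean of everything logged so far in the epoch,
    while a stateful metric is shown exactly as logged.
    """
    shown = []
    total = 0.0
    for count, value in enumerate(batch_values, start=1):
        total += value
        # Stateful metrics are displayed raw; others as the running mean.
        shown.append(value if stateful else total / count)
    return shown


print(displayed_values([5, 4, 3, 2, 1]))                 # → [5.0, 4.5, 4.0, 3.5, 3.0]
print(displayed_values([5, 4, 3, 2, 1], stateful=True))  # → [5, 4, 3, 2, 1]
```

The running mean is why a noisy per-batch loss looks smoother in the bar than in the raw values.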
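The second note follows from how the bar is drawn. The hypothetical `render_bar` below is a rough pure-Python sketch of the `current/target [====>...]` layout seen in the outputs above, assuming a 30-character bar; it is not the Keras implementation. Because `enumerate` counts from 0, the last training update in the loop passes `n_batches - 1`, which is why the bar stalls at 9/10 until an update at the target draws it full.

```python
def render_bar(current, target, width=30):
    """Build a Keras-style 'current/target [====>...]' string (layout sketch only)."""
    # Pad the step count to the width of the target, e.g. ' 9/10'.
    count = f"{current:>{len(str(target))}}/{target}"
    filled = int(width * current / target)
    if current < target:
        # Unfinished bar: '=' for completed cells, '>' as the head, '.' for the rest.
        bar = "=" * (filled - 1) + ">" + "." * (width - filled)
    else:
        # Finished bar: every cell is '='.
        bar = "=" * width
    return f"{count} [{bar}]"


n_batches = 10
# With 0-based indices, the last training update is idx = n_batches - 1.
print(render_bar(n_batches - 1, n_batches))  # stalls one step short
print(render_bar(n_batches, n_batches))      # the extra final update completes it
```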