
Fixing Keras Hangups in Jupyter Notebooks

Keras is a wonderful Python library for high-level implementation of deep learning networks. It provides a neat, customizable interface for designing intricate sequential and recurrent neural networks, along with fine-grained control over the training algorithm. For its backend, it transparently supports either Theano or TensorFlow, which abstract away the CPU and GPU implementations of the complicated algorithms and data flows. Modern compiler design and expression templates have really come a long way!

I love programming in Jupyter notebooks because they leave a reproducible record of my work and provide a very convenient interface for running code on a headless work machine or in the AWS/Google cloud. Jupyter notebooks are also used heavily in machine learning courses taught online and in classrooms, because they let the instructor hide the ugly details of setting up the environment inside VMs or cloud images.

One frequently encountered problem when training Keras models in Jupyter notebooks is a ‘WebSocket ping timeout’. My understanding of the issue is that the constant stream of training progress bar updates overwhelms the Jupyter client-server connection, and communication freezes. This is an often-referenced issue. Some of the suggested solutions involve redirecting stdout to a file to relieve the stress of progress bar updates, but those prevent you from seeing the training progress and important messages right in the notebook.

One elegant solution I like is disabling the default text progress bar in Keras and using the keras_tqdm progress bar instead. tqdm is a neat, modern-looking progress bar with Jupyter notebook support, so it does not flood the connection with constant updates. The author of keras_tqdm has put together a really convenient Keras callback class that draws and updates the progress bars in the notebook. I have successfully used it to fix my timeout issues when training Keras models.
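For reference, the keras_tqdm README shows the swap as passing `verbose=0` to `fit()` (which silences Keras's own progress bar) and adding `TQDMNotebookCallback()` to the `callbacks` list. The sketch below does not use Keras at all. It only illustrates one general way a progress bar can avoid flooding a connection: rate-limiting redraws. tqdm has a comparable setting in its `mininterval` parameter (0.1 seconds by default), but the class `ThrottledProgress`, its parameters, and the fake clock are hypothetical. Whether throttling is what actually prevents the timeout is an assumption, not something established above.

```python
import time
from typing import Callable, List


class ThrottledProgress:
    """Text progress reporter that redraws at most once per `min_interval`
    clock units, plus a final redraw once `total` is reached.

    Hypothetical illustration of rate-limited redraws; this is not the
    actual implementation of tqdm or keras_tqdm.
    """

    def __init__(self, total: int, min_interval: float = 0.1,
                 clock: Callable[[], float] = time.monotonic,
                 write: Callable[[str], None] = print) -> None:
        self.total = total
        self.min_interval = min_interval  # same units as `clock` (seconds by default)
        self.clock = clock
        self.write = write
        self.n = 0
        self.redraws = 0
        self._last_draw = float("-inf")  # guarantees the first update draws

    def update(self, k: int = 1) -> None:
        """Advance the counter by k; redraw only if enough time has passed
        since the last redraw, or if the total has been reached."""
        self.n += k
        now = self.clock()
        if self.n >= self.total or now - self._last_draw >= self.min_interval:
            self._last_draw = now
            self.redraws += 1
            self.write(f"{self.n}/{self.total}")


# Demo with a fake clock: 1000 batch updates, one tick apart,
# with redraws allowed at most every 100 ticks.
ticks = [0]
lines: List[str] = []
bar = ThrottledProgress(total=1000, min_interval=100,
                        clock=lambda: ticks[0], write=lines.append)
for _ in range(1000):
    ticks[0] += 1
    bar.update()

print(bar.redraws, lines[0], lines[-1])  # 11 1/1000 1000/1000
```

Here 1000 updates produce only 11 writes instead of 1000. The forced final redraw ensures the bar always finishes at 100% even when the last update falls inside the throttling window.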

One small improvement I contributed is a bug fix to support training with variable batch sizes. Keras provides two ‘fit’ functions for training: a fit() function for fixed-size batches, where all the data is loaded into one large numpy array, and a fit_generator() function, where batches are produced on the fly by a generator. Batch generators have several advantages. The most important one is that datasets too large to fit into memory can be processed by reading them from disk in chunks. Another significant advantage is being able to ‘generate’ data, for example by applying transforms and crops to images, by adding noise, or by synthesizing samples outright. Generators do not need to abide by a fixed batch size; they can yield batches with a different number of items each time (although many generators do yield a fixed size). This breaks the progress counting mechanism in the keras_tqdm code. I have detailed my process in the bug report and the pull request. Until the pull request is merged, you can use my fork. Follow these install instructions:

git clone https://github.com/rohitrawat/keras-tqdm.git
cd keras-tqdm
python setup.py install
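To make the variable-batch-size problem concrete without Keras, here is a small sketch. `variable_batches` is a hypothetical generator that slices a list using a repeating schedule of batch sizes, the way a chunked disk reader or an augmentation pipeline might yield uneven batches. `count_progress` compares a naive counter that assumes every batch has the nominal size against one that adds each batch's real length. This illustrates the general failure mode under those assumptions. It is not the actual keras_tqdm code or the fix in the pull request.

```python
from itertools import cycle
from typing import Iterable, Iterator, List, Sequence, Tuple


def variable_batches(data: Sequence[int], sizes: Sequence[int]) -> Iterator[List[int]]:
    """Yield consecutive slices of `data`, cycling through `sizes` for each
    batch length; the final batch holds whatever data remains."""
    if not sizes or any(s <= 0 for s in sizes):
        raise ValueError("sizes must be a non-empty sequence of positive ints")
    size_iter = cycle(sizes)
    start = 0
    while start < len(data):
        size = next(size_iter)
        yield list(data[start:start + size])
        start += size


def count_progress(batches: Iterable[List[int]], nominal_batch_size: int) -> Tuple[int, int]:
    """Return (naive, actual) sample counts after consuming every batch.

    naive:  adds nominal_batch_size per batch, only correct for fixed-size batches
    actual: adds len(batch), which stays correct when batch sizes vary
    """
    naive = actual = 0
    for batch in batches:
        naive += nominal_batch_size
        actual += len(batch)
    return naive, actual


data = list(range(100))
schedule = [32, 7, 20]
sizes_seen = [len(b) for b in variable_batches(data, schedule)]
naive, actual = count_progress(variable_batches(data, schedule), nominal_batch_size=32)

print(sizes_seen)     # [32, 7, 20, 32, 7, 2]
print(naive, actual)  # 192 100
```

The naive counter ends at 192 for a 100-sample dataset, so a progress bar driven by it would overshoot. Summing `len(batch)` lands exactly on the dataset size no matter how the generator sizes its batches.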

I hope this post helps people who are running into the WebSocket ping timeout error or, more generally, are unable to run keras_tqdm with fit_generator().