Keras comes with a long list of predefined callbacks that are ready to use. Keras callbacks are functions that are executed during the training process.
According to the Keras documentation, "a callback is a set of functions to be applied at given stages of the training procedure." You can use callbacks to get a view of the internal states and statistics of the model during training. You can pass a list of callbacks (as the keyword argument callbacks) to the .fit() method of the Sequential or Model classes, and the appropriate methods of each callback will then be called at each stage of training.
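Here is a minimal sketch of passing a list of callbacks to fit(). It trains a tiny model with two built-in callbacks, EarlyStopping and TerminateOnNaN. The synthetic data, the toy model and the patience value are assumptions made for illustration, and the sketch assumes a recent Keras 2.x or 3.x installation.

```python
import numpy as np
import keras

# Toy binary-classification data (illustrative assumption, not from the article)
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
y = (x[:, 0] > 0).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Built-in callbacks are passed to fit() as a list via the callbacks keyword
callbacks = [
    # Stop when val_loss has not improved for 2 consecutive epochs
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=2),
    # Abort training if the loss ever becomes NaN
    keras.callbacks.TerminateOnNaN(),
]

history = model.fit(x, y, epochs=20, batch_size=32, validation_split=0.25,
                    callbacks=callbacks, verbose=0)
print("epochs run:", len(history.history["loss"]))
```

EarlyStopping can end training early, so history.history["loss"] may contain fewer entries than the requested 20 epochs.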
In this article, we’ll discuss how to create custom callbacks in Keras.
KERAS CUSTOM CALLBACK:
Along with various built-in callbacks that are ready to use, we can also create custom callbacks in Keras.
In Keras, we can create custom callbacks in two ways:
- By creating a class that inherits from keras.callbacks.Callback
- By using the LambdaCallback
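This sketch puts the two approaches side by side on a toy model. A Callback subclass and a LambdaCallback each record the index of every finished epoch. The class name EpochRecorder and the synthetic data are hypothetical and exist only for illustration.

```python
import numpy as np
import keras

# Toy data and model (illustrative assumptions)
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = (x[:, 0] > 0).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")


# Way 1: subclass keras.callbacks.Callback and override the hooks you need
class EpochRecorder(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.epochs_seen = []

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_seen.append(epoch)


# Way 2: build an equivalent callback on the fly with LambdaCallback
lambda_epochs = []
lambda_recorder = keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: lambda_epochs.append(epoch)
)

class_recorder = EpochRecorder()
model.fit(x, y, epochs=3, batch_size=16, verbose=0,
          callbacks=[class_recorder, lambda_recorder])
print(class_recorder.epochs_seen, lambda_epochs)
```

Both recorders end up with [0, 1, 2] because Keras passes zero-based epoch indices to on_epoch_end. The subclass fits reusable callbacks that keep their own state. The lambda suits one-off hooks.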
USING CALLBACK CLASS:
Now let’s see how to create a custom callback class that computes the ROC AUC at the end of every epoch.
To create a custom callback, we create a class that inherits from keras.callbacks.Callback and override the methods we need.
Since we are calculating the ROC AUC at the end of each epoch, we'll override the on_epoch_end method.
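on_epoch_end is only one of several hooks that a Callback subclass can override. The hypothetical HookTracer below is a rough sketch, assuming a recent Keras and a toy model on synthetic data. It records the order in which the training-level and epoch-level hooks fire, and it keeps the logs dictionary passed to the last on_epoch_end call.

```python
import numpy as np
import keras

# Toy data and model (illustrative assumptions, not from the article)
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = (x[:, 0] > 0).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])


class HookTracer(keras.callbacks.Callback):
    """Hypothetical callback that records which hooks Keras calls, in order."""

    def __init__(self):
        super().__init__()
        self.calls = []
        self.last_logs = {}

    def on_train_begin(self, logs=None):
        self.calls.append("on_train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.calls.append(f"on_epoch_begin({epoch})")

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(f"on_epoch_end({epoch})")
        # Copy the epoch-level metrics Keras passes in (e.g. loss, accuracy)
        self.last_logs = dict(logs or {})

    def on_train_end(self, logs=None):
        self.calls.append("on_train_end")


tracer = HookTracer()
model.fit(x, y, epochs=2, batch_size=16, verbose=0, callbacks=[tracer])
print(tracer.calls)
print(sorted(tracer.last_logs))
```

For two epochs without validation data, tracer.calls starts with on_train_begin. It then holds one on_epoch_begin/on_epoch_end pair per epoch and ends with on_train_end.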
First, let’s create a CNN model that classifies the fashion_mnist data:
```python
# Based on https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.datasets import fashion_mnist

batch_size = 128
num_classes = 10
epochs = 50

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Building our CNN
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# compile the model
# learning_rate=1.0 matches the old Keras default for Adadelta;
# recent versions default to 0.001, which trains far more slowly here
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(learning_rate=1.0),
              metrics=['accuracy'])
```
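Next, we create our custom callback, a class that inherits from keras.callbacks.Callback. Recent versions of Keras (tf.keras in TensorFlow 2 and Keras 3) no longer populate self.validation_data inside callbacks. The validation set is therefore passed to the callback's constructor instead:
```python
from keras import callbacks
from sklearn.metrics import roc_auc_score

# creating a custom callback class
class ROCCallback(callbacks.Callback):
    def __init__(self, x_val, y_val):
        super().__init__()
        # store the validation set explicitly, since self.validation_data is not set
        self.x_val = x_val
        self.y_val = y_val

    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.model.predict(self.x_val, verbose=0)
        # one-hot labels + class probabilities -> macro-averaged one-vs-rest AUC
        auc = roc_auc_score(self.y_val, y_pred)
        print('Epoch', epoch + 1, 'AUC:', auc)
```
Now we can pass the callback object to the model's fit method as a list. The last 30% of the training data is held out explicitly so that the same arrays can go to both fit and the callback. This mirrors validation_split=0.3, which also takes its validation samples from the end of the training data:
```python
# hold out the last 30% of the training data for validation
split = int(len(x_train) * 0.7)
x_tr, x_val = x_train[:split], x_train[split:]
y_tr, y_val = y_train[:split], y_train[split:]

# instantiate the custom callback class
auc = ROCCallback(x_val, y_val)
history = model.fit(x_tr, y_tr,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_val, y_val),
                    callbacks=[auc])
```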
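The labels here are one-hot encoded and the model outputs class probabilities. In that case roc_auc_score treats the input as multilabel indicators. It computes a one-vs-rest AUC for each class column and, with its default average='macro', returns their unweighted mean. The sketch below checks this on random synthetic scores. The data is made up for illustration, and only NumPy and scikit-learn are needed.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
num_classes = 3

# One-hot labels, shape (200, 3), like keras.utils.to_categorical output
labels = rng.integers(0, num_classes, size=200)
y_true = np.eye(num_classes)[labels]

# Random scores normalized so each row sums to 1, like softmax output
scores = rng.random((200, num_classes))
y_pred = scores / scores.sum(axis=1, keepdims=True)

# Default average="macro": mean of the per-class (one-vs-rest) AUCs
macro_auc = roc_auc_score(y_true, y_pred)
per_class = [roc_auc_score(y_true[:, k], y_pred[:, k]) for k in range(num_classes)]
print(macro_auc, np.mean(per_class))
```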
LAMBDA CALLBACK:
Without writing a custom class, we can also create a simple custom callback on the fly using keras.callbacks.LambdaCallback.
The signature of the lambda callback is as follows.
```python
LambdaCallback(on_epoch_begin=None, on_epoch_end=None,
               on_batch_begin=None, on_batch_end=None,
               on_train_begin=None, on_train_end=None)
```
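Each LambdaCallback argument takes a plain function, and the arguments differ by hook:

- on_epoch_begin and on_epoch_end receive (epoch, logs).
- on_batch_begin and on_batch_end receive (batch, logs).
- on_train_begin and on_train_end receive only logs.

The sketch below uses three of these hooks to collect per-batch losses. It runs on a toy model with synthetic data, both assumed for illustration.

```python
import numpy as np
import keras

# Toy data and model (illustrative assumptions)
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = (x[:, 0] > 0).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

batch_losses = []
events = []

logger = keras.callbacks.LambdaCallback(
    # on_train_begin / on_train_end receive only logs
    on_train_begin=lambda logs: events.append("train_begin"),
    # on_batch_end receives (batch, logs); logs["loss"] is the loss reported after this batch
    on_batch_end=lambda batch, logs: batch_losses.append(float(logs["loss"])),
    on_train_end=lambda logs: events.append("train_end"),
)

model.fit(x, y, epochs=2, batch_size=16, verbose=0, callbacks=[logger])
print(len(batch_losses), events)
```

With 64 samples and batch_size=16 there are 4 batches per epoch, so two epochs produce 8 entries in batch_losses.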
Now let’s create a custom lambda callback that stops training once the validation accuracy exceeds 90%:
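```python
from keras.callbacks import LambdaCallback

def on_epoch_end(epoch, logs):
    THRESHOLD = 0.90
    # Keras >= 2.3 logs this metric as 'val_accuracy'; older versions used 'val_acc'
    if logs.get('val_accuracy', 0) > THRESHOLD:
        # 'model' is the global model defined earlier
        model.stop_training = True

lambdac = LambdaCallback(on_epoch_end=on_epoch_end)
```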
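The lambda above reaches the model through the global variable model. An alternative sketch is the hypothetical ThresholdStopping class below, which does the same job as a Callback subclass. It uses self.model, which Keras sets on every callback before training starts. The monitored key, the threshold and the toy data are assumptions. The demo monitors the training loss with a threshold of 0.0, which any cross-entropy loss exceeds, so training stops after the first epoch.

```python
import numpy as np
import keras


class ThresholdStopping(keras.callbacks.Callback):
    """Hypothetical helper: stop training once logs[monitor] exceeds threshold."""

    def __init__(self, monitor="val_accuracy", threshold=0.90):
        super().__init__()
        self.monitor = monitor
        self.threshold = threshold
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is not None and value > self.threshold:
            # self.model is attached by Keras, so no global variable is needed
            self.model.stop_training = True
            self.stopped_epoch = epoch


# Toy data and model (illustrative assumptions)
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = (x[:, 0] > 0).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Demo: loss is always > 0.0, so training stops after the first epoch
stopper = ThresholdStopping(monitor="loss", threshold=0.0)
history = model.fit(x, y, epochs=10, batch_size=16, verbose=0, callbacks=[stopper])
print("stopped at epoch", stopper.stopped_epoch, "after", len(history.history["loss"]), "epoch(s)")
```

In the article's setting the same class would be created with its defaults, ThresholdStopping(monitor="val_accuracy", threshold=0.90), and passed to fit alongside validation_split.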
Below is the complete code:
```python
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.datasets import fashion_mnist
from keras.callbacks import LambdaCallback

batch_size = 128
num_classes = 10
epochs = 12

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Building our CNN
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# compile the model
# learning_rate=1.0 matches the old Keras default for Adadelta
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(learning_rate=1.0),
              metrics=['accuracy'])

# creating a function that will be called at the end of each epoch
def on_epoch_end(epoch, logs):
    THRESHOLD = 0.90
    # 'val_accuracy' in Keras >= 2.3 ('val_acc' in older versions)
    if logs.get('val_accuracy', 0) > THRESHOLD:
        model.stop_training = True
        print('Stopping the training. Validation accuracy exceeded 90%')

lambdac = LambdaCallback(on_epoch_end=on_epoch_end)

history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_split=0.3,
                    callbacks=[lambdac])
```