In previous posts we reviewed the development of CNN architectures over the last decade and even defined our own custom architecture. In this post we train some models with Keras (TensorFlow). A well-designed structure is essential for any deep learning project, so we start by setting up a simple TensorFlow project, applying some good design principles, to train our cancer detection CNNs.

Project Design

Our project needs to provide at least these minimal functionalities: 1) provide access to data generators, 2) manage the training process, 3) wrap all models in a common interface, and 4) provide access to the job configuration. In the following sections we quickly walk through the code that does this, starting from the layout sketched below.
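The imports used throughout the post suggest a project layout roughly like the following; the configs directory and the exact file names are assumptions:

project/
├── base/
│   ├── base_data_generator.py   # BaseDataGenerator
│   └── base_trainer.py          # BaseTrainer
├── data/
│   └── generator.py             # DataGenerator
├── models/
│   ├── cancernet.py
│   ├── densenet169.py
│   ├── mobilenet2.py
│   └── resnet50.py
├── trainers/
│   └── trainer.py               # Trainer
├── utils/
│   ├── args.py                  # get_args
│   ├── config.py                # process_config
│   └── dirs.py                  # create_dirs
├── configs/                     # JSON job configurations (assumed location)
└── experiments/                 # per-run logs and checkpoints, created by process_config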

BaseDataGenerator

BaseDataGenerator is an abstract class that holds the configuration and wraps access to the data generators.

class BaseDataGenerator:
    def __init__(self, config):
        self.config = config

    def get_train_data(self):
        raise NotImplementedError

    def get_test_data(self):
        raise NotImplementedError
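The concrete DataGenerator (imported later from data.generator) is not shown in the post. A minimal sketch built on Keras' ImageDataGenerator might look like the following; the module path, directory layout, image size and augmentation settings are all assumptions:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from base.base_data_generator import BaseDataGenerator  # module path is an assumption

class DataGenerator(BaseDataGenerator):
    def __init__(self, config):
        super().__init__(config)
        # rescale pixels to [0, 1]; the flips are illustrative augmentations
        train_datagen = ImageDataGenerator(rescale=1./255, horizontal_flip=True, vertical_flip=True)
        plain_datagen = ImageDataGenerator(rescale=1./255)
        self.train_data_generator = train_datagen.flow_from_directory(
            'data/train',                 # assumed layout: one subfolder per class
            target_size=(96, 96), class_mode='binary',
            batch_size=config.trainer.batch_size)
        self.valid_data_generator = plain_datagen.flow_from_directory(
            'data/valid',
            target_size=(96, 96), class_mode='binary',
            batch_size=config.trainer.batch_size)
        self.test_data_generator = plain_datagen.flow_from_directory(
            'data/test',
            target_size=(96, 96), class_mode='binary',
            batch_size=config.trainer.batch_size,
            shuffle=False)                # keep order so predictions align with .classes

    def get_train_data(self):
        return self.train_data_generator

    def get_test_data(self):
        return self.test_data_generator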

BaseTrainer

BaseTrainer is an abstract class that wires up the data generators, holds the configuration and gives a common entry point to kick off the training process.

class BaseTrainer:
    def __init__(self, model, train_data_generator, valid_data_generator, config):
        self.model = model
        self.train_data_generator = train_data_generator
        self.valid_data_generator = valid_data_generator
        self.config = config

    def train(self):
        raise NotImplementedError

Model Wrapper

We want to be able to handle both the predefined models supplied by Keras and our own custom models without the calling code knowing the difference. In addition we need to customize the output layers of the predefined models to our own classification needs. We can achieve both of these aims simply by inheriting from tf.keras.Model. Below we see how this is done for a DenseNet169 model; in a previous post we have already seen this for our custom model.

import tensorflow as tf
from tensorflow.keras.layers import Dropout, GlobalMaxPooling2D, GlobalAveragePooling2D, Flatten, Dense, Concatenate

class DenseNet169Wrapper(tf.keras.Model):

    def __init__(self, classes=1, **kwargs):
        super().__init__(**kwargs)
        # DenseNet169 backbone without its classification head; weights=None
        # means we train from scratch rather than from ImageNet weights
        self.densenet = tf.keras.applications.DenseNet169(
            input_shape=(96, 96, 3),
            include_top=False,
            weights=None)
        self.global_max_pooling_2d = GlobalMaxPooling2D()
        self.global_avg_pooling_2d = GlobalAveragePooling2D()
        self.flatten = Flatten()
        self.concatenate = Concatenate(axis=-1)
        self.dropout = Dropout(0.5)
        self.dense = Dense(256, activation='relu')
        self.classifier = Dense(classes, activation='sigmoid')

    def call(self, inputs, **kwargs):
        x = self.densenet(inputs)
        # combine max-pooled, average-pooled and flattened views of the feature maps
        x1 = self.global_max_pooling_2d(x)
        x2 = self.global_avg_pooling_2d(x)
        x3 = self.flatten(x)
        x = self.concatenate([x1, x2, x3])
        x = self.dropout(x)
        x = self.dense(x)
        return self.classifier(x)
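Because this is a subclassed model its weights are only created on first call, so to inspect the architecture we can run a dummy batch through it first (a usage sketch, not from the post):

import tensorflow as tf

model = DenseNet169Wrapper(classes=1)
_ = model(tf.zeros((1, 96, 96, 3)))  # one dummy batch of 96x96 RGB tiles builds the weights
model.summary()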

Trainer

With both the data generators and models in place, it's time to define our training pipeline. We extend BaseTrainer, selecting from the Keras callbacks, optimizers and metrics those that suit our training and evaluation needs. Their settings come from the configuration held by BaseTrainer; we look at configuration handling below.

import os

from base.base_trainer import BaseTrainer
from tensorflow.keras import optimizers, metrics
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard, CSVLogger

class Trainer(BaseTrainer):
    def __init__(self, model, train_data_generator, valid_data_generator, config):
        super().__init__(model, train_data_generator, valid_data_generator, config)
        self.train_steps = self.train_data_generator.n // self.train_data_generator.batch_size
        self.valid_steps = self.valid_data_generator.n // self.valid_data_generator.batch_size
        self.callbacks = []
        self.callbacks.append(
            ModelCheckpoint(
                filepath=os.path.join(self.config.callbacks.checkpoint_dir, '%s.h5' % self.config.exp.name),
                monitor=self.config.callbacks.checkpoint_monitor,
                mode=self.config.callbacks.checkpoint_mode,
                verbose=self.config.callbacks.verbose,
                save_weights_only=self.config.callbacks.checkpoint_save_weights_only,
                save_best_only=self.config.callbacks.checkpoint_save_best_only)
        )
        self.callbacks.append(
            ReduceLROnPlateau(
                monitor=self.config.callbacks.reduce_on_plateau_monitor,
                factor=self.config.callbacks.reduce_on_plateau_factor,
                patience=self.config.callbacks.reduce_on_plateau_patience, 
                verbose=self.config.callbacks.verbose)
        )
        self.callbacks.append(
            TensorBoard(
                log_dir=self.config.callbacks.tensorboard_log_dir,
                write_graph=self.config.callbacks.tensorboard_write_graph)
        )
        self.callbacks.append(
            EarlyStopping(
                monitor=self.config.callbacks.early_stopping_monitor,
                patience=self.config.callbacks.early_stopping_patience,
                restore_best_weights=self.config.callbacks.early_stopping_restore_best_weights,
                verbose=self.config.callbacks.verbose)
        )
        self.callbacks.append(
            CSVLogger(
                filename=os.path.join(self.config.callbacks.evaluation_log_dir, '%s.csv' % self.config.exp.name))
        )


    def train(self):
        self.model.compile(
            # take the learning rate from the JSON config instead of hard-coding it
            optimizer=optimizers.Adam(learning_rate=self.config.model.learning_rate),
            loss='binary_crossentropy',
            metrics=[
                'accuracy',
                metrics.Precision(),
                metrics.Recall(),
                metrics.AUC()])
        self.model.fit(
            self.train_data_generator,
            epochs=self.config.trainer.epochs,
            steps_per_epoch=self.train_steps,
            validation_data=self.valid_data_generator,
            validation_steps=self.valid_steps,
            callbacks=self.callbacks,
            verbose=1)

Project Configuration

Configuration of each training run is defined in a JSON file, which allows many of the training parameters to be tweaked without changing code. For example you can change the model type and choice of optimizer, update optimizer settings, or choose a different learning rate. Parsing of the JSON file is handled by the following helper functions.

import json
from dotmap import DotMap
import os
import time


def get_config_from_json(json_file):
    """
    Get the config from a json file
    :param json_file:
    :return: config(namespace) or config(dictionary)
    """
    # parse the configurations from the config json file provided
    with open(json_file, 'r') as config_file:
        config_dict = json.load(config_file)

    # convert the dictionary to a namespace using bunch lib
    config = DotMap(config_dict)

    return config, config_dict


def process_config(json_file):
    config, _ = get_config_from_json(json_file)
    # group all artifacts of a run under experiments/<date>/<experiment name>/
    exp_dir = os.path.join("experiments", time.strftime("%Y-%m-%d/", time.localtime()), config.exp.name)
    config.callbacks.tensorboard_log_dir = os.path.join(exp_dir, "logs/tensorboard/")
    config.callbacks.evaluation_log_dir = os.path.join(exp_dir, "logs/evaluation/")
    config.callbacks.checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    return config
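DotMap gives attribute-style access to the nested configuration, which keeps the trainer code readable. A quick illustration, assuming a hypothetical config path:

config = process_config('configs/mobilenet2.json')  # hypothetical path
print(config.exp.name)                   # -> mobilenet2
print(config.callbacks.checkpoint_dir)   # -> experiments/<date>/mobilenet2/checkpoints/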

A configuration for a training run might look something like this:

{
    "exp": {
      "name": "mobilenet2"
    },
    "model":{
      "name": "mobilenet2",
      "learning_rate": 0.001,
      "optimizer": "adam"
    },
    "trainer":{
      "epochs": 20,
      "batch_size": 64,
      "verbose_training": true
    },
    "callbacks":{
      "checkpoint_monitor": "val_accuracy",
      "checkpoint_mode": "max",
      "checkpoint_save_best_only": true,
      "checkpoint_save_weights_only": false,
      "early_stopping_monitor": "val_loss",
      "early_stopping_patience": 3,
      "early_stopping_restore_best_weights": true,
      "reduce_on_plateau_monitor": "val_loss",
      "reduce_on_plateau_factor": 0.1,
      "reduce_on_plateau_patience": 1,
      "tensorboard_write_graph": false,
      "verbose": true
    }
  }

Execution

Finally we are ready to run our model training. The code below ties all the parts together in a command-line application: the command-line arguments supply the config file, which is used to select the model type, configure the trainer and start the training. When training is complete, the best model checkpoint is loaded and the test data is used to generate performance metrics, which are persisted for off-line visualization.
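The script imports two small helpers, get_args and create_dirs, that the post does not show. Minimal versions might look like this (a sketch; only the --config argument is relied on by the script, the rest is illustrative):

import argparse
import os

def get_args():
    # a single argument pointing at the JSON job configuration
    parser = argparse.ArgumentParser(description='Train a cancer detection CNN')
    parser.add_argument('-c', '--config', required=True,
                        help='path to the JSON configuration file')
    return parser.parse_args()

def create_dirs(dirs):
    # create each experiment directory if it does not already exist
    for directory in dirs:
        os.makedirs(directory, exist_ok=True)

The main script itself: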

import os
import sys

from data.generator import DataGenerator
from models.mobilenet2 import MobileNet2Wrapper
from models.resnet50 import ResNet50Wrapper
from models.densenet169 import DenseNet169Wrapper
from models.cancernet import CancerNet
from trainers.trainer import Trainer
from utils.config import process_config
from utils.dirs import create_dirs
from utils.args import get_args

from sklearn.metrics import roc_curve
import numpy as np

def main():

    try:
        args = get_args()
        config = process_config(args.config)
    except Exception as err:
        print("missing or invalid arguments:", err)
        sys.exit(1)

    create_dirs([
        config.callbacks.tensorboard_log_dir,
        config.callbacks.evaluation_log_dir,
        config.callbacks.checkpoint_dir])

    print('Create the data generator.')
    data = DataGenerator(config)

    # map config names to model classes so only the requested model is instantiated
    models = {
        "cancernet": CancerNet,
        "densenet169": DenseNet169Wrapper,
        "mobilenet2": MobileNet2Wrapper,
        "resnet50": ResNet50Wrapper,
    }

    if config.model.name not in models:
        sys.exit("ERROR: invalid model name: %s" % config.model.name)
    model = models[config.model.name]()

    print("Train classifier:",  model.__class__)
    trainer = Trainer(model, data.train_data_generator, data.valid_data_generator, config)
    trainer.train()

    model.load_weights(
        filepath=os.path.join(config.callbacks.checkpoint_dir, '%s.h5' % config.exp.name))

    loss, accuracy, precision, recall, auc = model.evaluate(
        data.test_data_generator)
    
    predictions = model.predict(
        data.test_data_generator,
        # cover every test sample exactly once, even if the last batch is smaller
        steps=int(np.ceil(data.test_data_generator.n / data.test_data_generator.batch_size)),
        verbose=1)

    fpr, tpr, thresholds = roc_curve(
        data.test_data_generator.classes,
        predictions)

    np.savez(
        file=os.path.join(config.callbacks.evaluation_log_dir, '%s-metrics.npz' % config.exp.name), 
        loss=loss, accuracy=accuracy, precision=precision, recall=recall, auc=auc,
        fpr=fpr, tpr=tpr, thresholds=thresholds)


if __name__ == '__main__':
    main()
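With everything in place, a training run is launched from the command line. Assuming the script above is saved as main.py and the configuration shown earlier lives at configs/mobilenet2.json (both names are assumptions), that would be:

python main.py -c configs/mobilenet2.json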